|
12 | 12 | from util import msg, profile, runparams, io |
13 | 13 |
|
14 | 14 |
|
15 | | -def doit(solver_name, problem_name, param_file, |
16 | | - other_commands=None, |
17 | | - comp_bench=False, reset_bench_on_fail=False, make_bench=False): |
18 | | - """The main driver to run pyro""" |
| 15 | +class Pyro(object): |
| 16 | + """ |
| 17 | + The main driver to run pyro. |
| 18 | +
|
| 19 | + Notes / TODO: |
| 20 | + * Further decouple some of the benchmarking stuff? |
| 21 | + * Should it be possible to pass in a problem function and initialise a Pyro/ |
| 22 | + Simulation object using that? That would also require some modifications |
| 23 | + of the Simulation class, as we currently pass the problem name to its |
| 24 | + constructor. I think this would make sense though if we were moving |
| 25 | + towards something that worked better as a Jupyter Notebook. |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__(self, solver_name, comp_bench=False, |
| 29 | + reset_bench_on_fail=False, make_bench=False): |
| 30 | + """ |
| 31 | + Constructor |
| 32 | +
|
| 33 | + Parameters |
| 34 | + ---------- |
| 35 | + solver_name : str |
| 36 | + Name of solver to use |
| 37 | + comp_bench : bool |
| 38 | + Are we comparing to a benchmark? |
| 39 | + reset_bench_on_fail : bool |
| 40 | + Do we reset the benchmark on fail? |
| 41 | + make_bench : bool |
| 42 | + Are we storing a benchmark? |
| 43 | + """ |
| 44 | + msg.bold('pyro ...') |
| 45 | + |
| 46 | + # import desired solver under "solver" namespace |
| 47 | + self.solver = importlib.import_module(solver_name) |
| 48 | + self.solver_name = solver_name |
| 49 | + self.comp_bench = comp_bench |
| 50 | + self.reset_bench_on_fail = reset_bench_on_fail |
| 51 | + self.make_bench = make_bench |
| 52 | + |
| 53 | + #------------------------------------------------------------------------- |
| 54 | + # runtime parameters |
| 55 | + #------------------------------------------------------------------------- |
| 56 | + |
| 57 | + # parameter defaults |
| 58 | + self.rp = runparams.RuntimeParameters() |
| 59 | + self.rp.load_params("_defaults") |
| 60 | + self.rp.load_params(solver_name + "/_defaults") |
| 61 | + |
| 62 | + self.tc = profile.TimerCollection() |
| 63 | + |
| 64 | + def initialize_problem(self, problem_name, param_file=None, param_dict=None, |
| 65 | + other_commands=None): |
| 66 | + """ |
| 67 | + Initialize the specific problem |
| 68 | +
|
| 69 | + Parameters |
| 70 | + ---------- |
| 71 | + problem_name : str |
| 72 | + Name of the problem |
| 73 | + param_file : str |
| 74 | + Filename containing problem's runtime parameters |
| 75 | + param_dict : dict |
| 76 | + Dictionary containing extra runtime parameters |
| 77 | + other_commands : str |
| 78 | + Other command line parameter options |
| 79 | + """ |
| 80 | + |
| 81 | + # problem-specific runtime parameters |
| 82 | + self.rp.load_params(self.solver_name + "/problems/_" + problem_name + ".defaults") |
| 83 | + |
| 84 | + # now read in the inputs file |
| 85 | + if param_dict is not None: |
| 86 | + for k, v in param_dict.items(): |
| 87 | + self.rp.params[k] = v |
| 88 | + if param_file is not None: |
| 89 | + if not os.path.isfile(param_file): |
| 90 | + # check if the param file lives in the solver's problems directory |
| 91 | + param_file = self.solver_name + "/problems/" + param_file |
| 92 | + if not os.path.isfile(param_file): |
| 93 | + msg.fail("ERROR: inputs file does not exist") |
| 94 | + |
| 95 | + self.rp.load_params(param_file, no_new=1) |
| 96 | + |
| 97 | + # and any commandline overrides |
| 98 | + if other_commands is not None: |
| 99 | + self.rp.command_line_params(other_commands) |
| 100 | + |
| 101 | + # write out the inputs.auto |
| 102 | + self.rp.print_paramfile() |
| 103 | + |
| 104 | + self.verbose = self.rp.get_param("driver.verbose") |
| 105 | + self.dovis = self.rp.get_param("vis.dovis") |
| 106 | + |
| 107 | + #------------------------------------------------------------------------- |
| 108 | + # initialization |
| 109 | + #------------------------------------------------------------------------- |
| 110 | + |
| 111 | + # initialize the Simulation object -- this will hold the grid and |
| 112 | + # data and know about the runtime parameters and which problem we |
| 113 | + # are running |
| 114 | + self.sim = self.solver.Simulation(self.solver_name, problem_name, self.rp, timers=self.tc) |
| 115 | + |
| 116 | + self.sim.initialize() |
| 117 | + self.sim.preevolve() |
| 118 | + |
| 119 | + def run_sim(self): |
| 120 | + """ |
| 121 | + Evolve entire simulation |
| 122 | + """ |
| 123 | + |
| 124 | + tm_main = self.tc.timer("main") |
| 125 | + tm_main.begin() |
| 126 | + |
| 127 | + plt.ion() |
| 128 | + |
| 129 | + self.sim.cc_data.t = 0.0 |
| 130 | + |
| 131 | + # output the 0th data |
| 132 | + basename = self.rp.get_param("io.basename") |
| 133 | + self.sim.write("{}{:04d}".format(basename, self.sim.n)) |
| 134 | + |
| 135 | + if self.dovis: |
| 136 | + plt.figure(num=1, figsize=(8, 6), dpi=100, facecolor='w') |
| 137 | + self.sim.dovis() |
| 138 | + |
| 139 | + while not self.sim.finished(): |
| 140 | + |
| 141 | + self.single_step() |
| 142 | + |
| 143 | + # final output |
| 144 | + if self.verbose > 0: |
| 145 | + msg.warning("outputting...") |
| 146 | + basename = self.rp.get_param("io.basename") |
| 147 | + self.sim.write("{}{:04d}".format(basename, self.sim.n)) |
| 148 | + |
| 149 | + tm_main.end() |
19 | 150 |
|
20 | | - msg.bold('pyro ...') |
| 151 | + result = self.compare_to_benchmark() |
| 152 | + self.make_bench(result) |
| 153 | + |
| 154 | + #------------------------------------------------------------------------- |
| 155 | + # final reports |
| 156 | + #------------------------------------------------------------------------- |
| 157 | + if self.verbose > 0: |
| 158 | + self.rp.print_unused_params() |
| 159 | + self.tc.report() |
21 | 160 |
|
22 | | - tc = profile.TimerCollection() |
| 161 | + self.sim.finalize() |
23 | 162 |
|
24 | | - tm_main = tc.timer("main") |
25 | | - tm_main.begin() |
26 | | - |
27 | | - # import desired solver under "solver" namespace |
28 | | - solver = importlib.import_module(solver_name) |
29 | | - |
30 | | - #------------------------------------------------------------------------- |
31 | | - # runtime parameters |
32 | | - #------------------------------------------------------------------------- |
33 | | - |
34 | | - # parameter defaults |
35 | | - rp = runparams.RuntimeParameters() |
36 | | - rp.load_params("_defaults") |
37 | | - rp.load_params(solver_name + "/_defaults") |
38 | | - |
39 | | - # problem-specific runtime parameters |
40 | | - rp.load_params(solver_name + "/problems/_" + problem_name + ".defaults") |
41 | | - |
42 | | - # now read in the inputs file |
43 | | - if not os.path.isfile(param_file): |
44 | | - # check if the param file lives in the solver's problems directory |
45 | | - param_file = solver_name + "/problems/" + param_file |
46 | | - if not os.path.isfile(param_file): |
47 | | - msg.fail("ERROR: inputs file does not exist") |
48 | | - |
49 | | - rp.load_params(param_file, no_new=1) |
50 | | - |
51 | | - # and any commandline overrides |
52 | | - if other_commands is not None: |
53 | | - rp.command_line_params(other_commands) |
54 | | - |
55 | | - # write out the inputs.auto |
56 | | - rp.print_paramfile() |
57 | | - |
58 | | - #------------------------------------------------------------------------- |
59 | | - # initialization |
60 | | - #------------------------------------------------------------------------- |
61 | | - |
62 | | - # initialize the Simulation object -- this will hold the grid and |
63 | | - # data and know about the runtime parameters and which problem we |
64 | | - # are running |
65 | | - sim = solver.Simulation(solver_name, problem_name, rp, timers=tc) |
66 | | - |
67 | | - sim.initialize() |
68 | | - sim.preevolve() |
69 | | - |
70 | | - #------------------------------------------------------------------------- |
71 | | - # evolve |
72 | | - #------------------------------------------------------------------------- |
73 | | - verbose = rp.get_param("driver.verbose") |
74 | | - |
75 | | - plt.ion() |
76 | | - |
77 | | - sim.cc_data.t = 0.0 |
78 | | - |
79 | | - # output the 0th data |
80 | | - basename = rp.get_param("io.basename") |
81 | | - sim.write("{}{:04d}".format(basename, sim.n)) |
82 | | - |
83 | | - dovis = rp.get_param("vis.dovis") |
84 | | - if dovis: |
85 | | - plt.figure(num=1, figsize=(8, 6), dpi=100, facecolor='w') |
86 | | - sim.dovis() |
87 | | - |
88 | | - while not sim.finished(): |
| 163 | + if self.comp_bench: |
| 164 | + return result |
| 165 | + else: |
| 166 | + return self.sim |
89 | 167 |
|
| 168 | + def single_step(self): |
| 169 | + """ |
| 170 | + Do a single step |
| 171 | + """ |
90 | 172 | # fill boundary conditions |
91 | | - sim.cc_data.fill_BC_all() |
| 173 | + self.sim.cc_data.fill_BC_all() |
92 | 174 |
|
93 | 175 | # get the timestep |
94 | | - sim.compute_timestep() |
| 176 | + self.sim.compute_timestep() |
95 | 177 |
|
96 | 178 | # evolve for a single timestep |
97 | | - sim.evolve() |
| 179 | + self.sim.evolve() |
98 | 180 |
|
99 | | - if verbose > 0: |
100 | | - print("%5d %10.5f %10.5f" % (sim.n, sim.cc_data.t, sim.dt)) |
| 181 | + if self.verbose > 0: |
| 182 | + print("%5d %10.5f %10.5f" % (self.sim.n, self.sim.cc_data.t, self.sim.dt)) |
101 | 183 |
|
102 | 184 | # output |
103 | | - if sim.do_output(): |
104 | | - if verbose > 0: |
| 185 | + if self.sim.do_output(): |
| 186 | + if self.verbose > 0: |
105 | 187 | msg.warning("outputting...") |
106 | | - basename = rp.get_param("io.basename") |
107 | | - sim.write("{}{:04d}".format(basename, sim.n)) |
| 188 | + basename = self.rp.get_param("io.basename") |
| 189 | + self.sim.write("{}{:04d}".format(basename, self.sim.n)) |
108 | 190 |
|
109 | 191 | # visualization |
110 | | - if dovis: |
111 | | - tm_vis = tc.timer("vis") |
| 192 | + if self.dovis: |
| 193 | + tm_vis = self.tc.timer("vis") |
112 | 194 | tm_vis.begin() |
113 | 195 |
|
114 | | - sim.dovis() |
115 | | - store = rp.get_param("vis.store_images") |
| 196 | + self.sim.dovis() |
| 197 | + store = self.rp.get_param("vis.store_images") |
116 | 198 |
|
117 | 199 | if store == 1: |
118 | | - basename = rp.get_param("io.basename") |
119 | | - plt.savefig("{}{:04d}.png".format(basename, sim.n)) |
| 200 | + basename = self.rp.get_param("io.basename") |
| 201 | + plt.savefig("{}{:04d}.png".format(basename, self.sim.n)) |
120 | 202 |
|
121 | 203 | tm_vis.end() |
122 | 204 |
|
123 | | - # final output |
124 | | - if verbose > 0: |
125 | | - msg.warning("outputting...") |
126 | | - basename = rp.get_param("io.basename") |
127 | | - sim.write("{}{:04d}".format(basename, sim.n)) |
128 | | - |
129 | | - tm_main.end() |
130 | | - |
131 | | - #------------------------------------------------------------------------- |
132 | | - # benchmarks (for regression testing) |
133 | | - #------------------------------------------------------------------------- |
134 | | - result = 0 |
135 | | - # are we comparing to a benchmark? |
136 | | - if comp_bench: |
137 | | - compare_file = "{}/tests/{}{:04d}".format( |
138 | | - solver_name, basename, sim.n) |
139 | | - msg.warning("comparing to: {} ".format(compare_file)) |
140 | | - try: |
141 | | - sim_bench = io.read(compare_file) |
142 | | - except IOError: |
143 | | - msg.warning("ERROR openning compare file") |
144 | | - return "ERROR openning compare file" |
145 | | - |
146 | | - result = compare.compare(sim.cc_data, sim_bench.cc_data) |
147 | | - |
148 | | - if result == 0: |
149 | | - msg.success("results match benchmark\n") |
150 | | - else: |
151 | | - msg.warning("ERROR: " + compare.errors[result] + "\n") |
| 205 | + def compare_to_benchmark(self): |
| 206 | + """ Are we comparing to a benchmark? """ |
152 | 207 |
|
153 | | - # are we storing a benchmark? |
154 | | - if make_bench or (result != 0 and reset_bench_on_fail): |
155 | | - if not os.path.isdir(solver_name + "/tests/"): |
156 | | - try: |
157 | | - os.mkdir(solver_name + "/tests/") |
158 | | - except (FileNotFoundError, PermissionError): |
159 | | - msg.fail("ERROR: unable to create the solver's tests/ directory") |
| 208 | + result = 0 |
160 | 209 |
|
161 | | - bench_file = solver_name + "/tests/" + basename + "%4.4d" % (sim.n) |
162 | | - msg.warning("storing new benchmark: {}\n".format(bench_file)) |
163 | | - sim.write(bench_file) |
| 210 | + if self.comp_bench: |
| 211 | + basename = self.rp.get_param("io.basename") |
| 212 | + compare_file = "{}/tests/{}{:04d}".format( |
| 213 | + self.solver_name, basename, self.sim.n) |
| 214 | + msg.warning("comparing to: {} ".format(compare_file)) |
| 215 | + try: |
| 216 | + sim_bench = io.read(compare_file) |
| 217 | + except IOError: |
| 218 | + msg.warning("ERROR openning compare file") |
| 219 | + return "ERROR openning compare file" |
164 | 220 |
|
165 | | - #------------------------------------------------------------------------- |
166 | | - # final reports |
167 | | - #------------------------------------------------------------------------- |
168 | | - if verbose > 0: |
169 | | - rp.print_unused_params() |
170 | | - tc.report() |
| 221 | + result = compare.compare(self.sim.cc_data, sim_bench.cc_data) |
171 | 222 |
|
172 | | - sim.finalize() |
| 223 | + if result == 0: |
| 224 | + msg.success("results match benchmark\n") |
| 225 | + else: |
| 226 | + msg.warning("ERROR: " + compare.errors[result] + "\n") |
173 | 227 |
|
174 | | - if comp_bench: |
175 | 228 | return result |
176 | | - else: |
177 | | - return sim |
178 | 229 |
|
| 230 | + def store_as_benchmark(self, result): |
| 231 | + """ Are we storing a benchmark? """ |
| 232 | + if self.make_bench or (result != 0 and self.reset_bench_on_fail): |
| 233 | + if not os.path.isdir(self.solver_name + "/tests/"): |
| 234 | + try: |
| 235 | + os.mkdir(self.solver_name + "/tests/") |
| 236 | + except (FileNotFoundError, PermissionError): |
| 237 | + msg.fail("ERROR: unable to create the solver's tests/ directory") |
179 | 238 |
|
180 | | -def parse_and_run(): |
181 | | - """Parse the runtime parameters and run a pyro instance""" |
| 239 | + basename = self.rp.get_param("io.basename") |
| 240 | + bench_file = self.solver_name + "/tests/" + basename + "%4.4d" % (self.sim.n) |
| 241 | + msg.warning("storing new benchmark: {}\n".format(bench_file)) |
| 242 | + self.sim.write(bench_file) |
| 243 | + |
| 244 | + |
| 245 | +def parse_args(): |
| 246 | + """Parse the runtime parameters""" |
182 | 247 |
|
183 | 248 | valid_solvers = ["advection", |
184 | 249 | "advection_rk", |
@@ -214,13 +279,13 @@ def parse_and_run(): |
214 | 279 | help="additional runtime parameters that override the inputs file " |
215 | 280 | "in the format section.option=value") |
216 | 281 |
|
217 | | - args = p.parse_args() |
218 | | - |
219 | | - doit(args.solver[0], args.problem[0], args.param[0], |
220 | | - other_commands=args.other, |
221 | | - comp_bench=args.compare_benchmark, |
222 | | - make_bench=args.make_benchmark) |
| 282 | + return p.parse_args() |
223 | 283 |
|
224 | 284 |
|
if __name__ == "__main__":
    # parse the command line, build the driver, and run
    cli = parse_args()
    driver = Pyro(cli.solver[0],
                  comp_bench=cli.compare_benchmark,
                  make_bench=cli.make_benchmark)
    driver.initialize_problem(problem_name=cli.problem[0],
                              param_file=cli.param[0],
                              other_commands=cli.other)
    driver.run_sim()
0 commit comments