Skip to content

Commit eb74d49

Browse files
authored
Introduce openmc.lib.TemporarySession context manager (#3475)
1 parent ecfb666 commit eb74d49

9 files changed

Lines changed: 163 additions & 79 deletions

File tree

docs/source/pythonapi/capi.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ Classes
8989
SphericalMesh
9090
SurfaceFilter
9191
Tally
92+
TemporarySession
9293
UniverseFilter
9394
UnstructuredMesh
9495
WeightFilter

openmc/deplete/chain.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections import defaultdict, namedtuple
1212
from collections.abc import Mapping, Iterable
1313
from numbers import Real, Integral
14+
from pathlib import Path
1415
from warnings import warn
1516
from typing import List
1617

@@ -278,7 +279,6 @@ def __len__(self):
278279
"""Number of nuclides in chain."""
279280
return len(self.nuclides)
280281

281-
282282
@property
283283
def stable_nuclides(self) -> List[Nuclide]:
284284
"""List of stable nuclides available in the chain"""
@@ -298,6 +298,7 @@ def add_nuclide(self, nuclide: Nuclide):
298298
Nuclide to add
299299
300300
"""
301+
_invalidate_chain_cache(self)
301302
self.nuclide_dict[nuclide.name] = len(self.nuclides)
302303
self.nuclides.append(nuclide)
303304

@@ -463,7 +464,6 @@ def from_endf(cls, decay_files, fpy_files, neutron_files,
463464
nuclide.add_reaction('fission', None, q_value, 1.0)
464465
fissionable = True
465466

466-
467467
if fissionable:
468468
if parent in fpy_data:
469469
fpy = fpy_data[parent]
@@ -558,6 +558,9 @@ def from_xml(cls, filename, fission_q=None):
558558
nuc = Nuclide.from_xml(nuclide_elem, root, this_q)
559559
chain.add_nuclide(nuc)
560560

561+
# Store path of XML file (used for handling cache invalidation)
562+
chain._xml_path = str(Path(filename).resolve())
563+
561564
return chain
562565

563566
def export_to_xml(self, filename):
@@ -888,7 +891,7 @@ def set_branch_ratios(self, branch_ratios, reaction="(n,gamma)",
888891
--------
889892
:meth:`get_branch_ratios`
890893
"""
891-
894+
_invalidate_chain_cache(self)
892895
# Store some useful information through the validation stage
893896

894897
sums = {}
@@ -1027,6 +1030,7 @@ def fission_yields(self):
10271030

10281031
@fission_yields.setter
10291032
def fission_yields(self, yields):
1033+
_invalidate_chain_cache(self)
10301034
if yields is not None:
10311035
if isinstance(yields, Mapping):
10321036
yields = [yields]
@@ -1249,6 +1253,10 @@ def _follow(self, isotopes, level):
12491253
return found
12501254

12511255

1256+
# A global cache for Chain objects
1257+
_CHAIN_CACHE = {}
1258+
1259+
12521260
def _get_chain(
12531261
chain_file: PathLike | Chain | None = None,
12541262
fission_q: dict | None = None
@@ -1269,16 +1277,39 @@ def _get_chain(
12691277
Chain
12701278
Depletion chain instance.
12711279
"""
1280+
# If chain_file is already a Chain, return it directly
12721281
if isinstance(chain_file, Chain):
12731282
return chain_file
1274-
elif isinstance(chain_file, PathLike | None):
1275-
if chain_file is None:
1276-
chain_file = openmc.config.get('chain_file')
1277-
if 'chain_file' not in openmc.config:
1278-
raise DataError(
1279-
"No depletion chain specified and could not find depletion "
1280-
"chain in openmc.config['chain_file']"
1281-
)
1282-
return Chain.from_xml(chain_file, fission_q)
1283-
else:
1283+
1284+
# Resolve chain_file based on config if None
1285+
if chain_file is None:
1286+
chain_file = openmc.config.get('chain_file')
1287+
if 'chain_file' not in openmc.config:
1288+
raise DataError(
1289+
"No depletion chain specified and could not find depletion "
1290+
"chain in openmc.config['chain_file']"
1291+
)
1292+
elif not isinstance(chain_file, PathLike):
12841293
raise TypeError("chain_file must be path-like, a Chain, or None")
1294+
1295+
# Determine the key for the cache, which consists of the absolute path, the
1296+
# file modification time, the file size, and the fission Q values.
1297+
chain_path = Path(chain_file).resolve()
1298+
stat_result = chain_path.stat()
1299+
fq_tuple = tuple(sorted(fission_q.items())) if fission_q else ()
1300+
key = (chain_path, stat_result.st_mtime, stat_result.st_size, fq_tuple)
1301+
1302+
# Check the global cache. If not cached, load the chain from XML and store
1303+
global _CHAIN_CACHE
1304+
if key not in _CHAIN_CACHE:
1305+
_CHAIN_CACHE[key] = Chain.from_xml(chain_path, fission_q)
1306+
return _CHAIN_CACHE[key]
1307+
1308+
1309+
def _invalidate_chain_cache(chain):
1310+
"""Invalidate the cache for a specific Chain (when it is modifed)."""
1311+
if hasattr(chain, '_xml_path'):
1312+
# Remove all entries with the same path as self._xml_path
1313+
for key in list(_CHAIN_CACHE.keys()):
1314+
if str(key[0]) == chain._xml_path:
1315+
del _CHAIN_CACHE[key]

openmc/deplete/microxs.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515

1616
from openmc.checkvalue import check_type, check_value, check_iterable_type, PathLike
17-
from openmc.utility_funcs import change_directory
1817
from openmc import StatePoint
1918
from openmc.mgxs import GROUP_STRUCTURES
2019
from openmc.data import REACTION_MT
@@ -286,6 +285,10 @@ def from_multigroup_flux(
286285
sections available. MicroXS entry will be 0 if the nuclide cross section
287286
is not found.
288287
288+
It is recommended to make repeated calls to this method within a context
289+
manager using the :class:`openmc.lib.TemporarySession` class to avoid
290+
re-initializing OpenMC and loading cross sections each time.
291+
289292
.. versionadded:: 0.15.0
290293
291294
Parameters
@@ -349,37 +352,23 @@ def from_multigroup_flux(
349352
# Create 3D array for microscopic cross sections
350353
microxs_arr = np.zeros((len(nuclides), len(mts), 1))
351354

352-
# Create a material with all nuclides
353-
mat_all_nucs = openmc.Material()
354-
for nuc in nuclides:
355-
if nuc in nuclides_with_data:
356-
mat_all_nucs.add_nuclide(nuc, 1.0)
357-
mat_all_nucs.set_density("atom/b-cm", 1.0)
358-
359-
# Create simple model containing the above material
360-
surf1 = openmc.Sphere(boundary_type="vacuum")
361-
surf1_cell = openmc.Cell(fill=mat_all_nucs, region=-surf1)
362-
model = openmc.Model()
363-
model.geometry = openmc.Geometry([surf1_cell])
364-
model.settings = openmc.Settings(
365-
particles=1, batches=1, output={'summary': False})
366-
367-
with change_directory(tmpdir=True):
368-
# Export model within temporary directory
369-
model.export_to_model_xml()
370-
371-
with openmc.lib.run_in_memory(**init_kwargs):
372-
# For each nuclide and reaction, compute the flux-averaged
373-
# cross section
374-
for nuc_index, nuc in enumerate(nuclides):
375-
if nuc not in nuclides_with_data:
376-
continue
377-
lib_nuc = openmc.lib.nuclides[nuc]
378-
for mt_index, mt in enumerate(mts):
379-
xs = lib_nuc.collapse_rate(
380-
mt, temperature, energies, multigroup_flux
381-
)
382-
microxs_arr[nuc_index, mt_index, 0] = xs
355+
def compute_microxs():
356+
# For each nuclide and reaction, compute the flux-averaged xs
357+
for nuc_index, nuc in enumerate(nuclides):
358+
if nuc not in nuclides_with_data:
359+
continue
360+
lib_nuc = openmc.lib.load_nuclide(nuc)
361+
for mt_index, mt in enumerate(mts):
362+
microxs_arr[nuc_index, mt_index, 0] = lib_nuc.collapse_rate(
363+
mt, temperature, energies, multigroup_flux
364+
)
365+
366+
# Compute microscopic cross sections within a temporary session
367+
if not openmc.lib.is_initialized:
368+
with openmc.lib.TemporarySession(**init_kwargs):
369+
compute_microxs()
370+
else:
371+
compute_microxs()
383372

384373
return cls(microxs_arr, nuclides, reactions)
385374

openmc/lib/core.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
c_uint64, c_size_t)
55
import sys
66
import os
7+
from pathlib import Path
78
from random import getrandbits
9+
from tempfile import TemporaryDirectory
810

911
import numpy as np
1012
from numpy.ctypeslib import as_array
@@ -617,6 +619,69 @@ def run_in_memory(**kwargs):
617619
finalize()
618620

619621

622+
class TemporarySession:
623+
"""Context manager for running via openmc.lib in a temporary directory.
624+
625+
This class is useful for accessing functionality from openmc.lib without
626+
polluting your current working directory with OpenMC files. It is used
627+
internally as a persistent session to avoid loading cross sections multiple
628+
times.
629+
630+
Parameters
631+
----------
632+
model : openmc.Model, optional
633+
OpenMC model to use for the session. If None, a minimal working model is
634+
created.
635+
**init_kwargs
636+
Keyword arguments to pass to :func:`openmc.lib.init`.
637+
638+
Attributes
639+
----------
640+
model : openmc.Model
641+
The OpenMC model used for the session.
642+
643+
"""
644+
def __init__(self, model=None, **init_kwargs):
645+
self.init_kwargs = init_kwargs
646+
if model is None:
647+
surf = openmc.Sphere(boundary_type="vacuum")
648+
cell = openmc.Cell(region=-surf)
649+
model = openmc.Model()
650+
model.geometry = openmc.Geometry([cell])
651+
model.settings = openmc.Settings(
652+
particles=1, batches=1, output={'summary': False})
653+
self.model = model
654+
655+
def __enter__(self):
656+
"""Initialize the OpenMC library in a temporary directory."""
657+
# Make sure OpenMC is not already initialized
658+
if openmc.lib.is_initialized:
659+
raise RuntimeError("openmc.lib is already initialized.")
660+
661+
# Store original working directory
662+
self.orig_dir = Path.cwd()
663+
664+
# Set up temporary directory
665+
self.tmp_dir = TemporaryDirectory()
666+
working_dir = Path(self.tmp_dir.name)
667+
working_dir.mkdir(parents=True, exist_ok=True)
668+
os.chdir(working_dir)
669+
670+
# Export model and initialize OpenMC
671+
self.model.export_to_model_xml()
672+
openmc.lib.init(**self.init_kwargs)
673+
674+
return self
675+
676+
def __exit__(self, exc_type, exc_value, traceback):
677+
"""Finalize the OpenMC library and clean up temporary directory."""
678+
try:
679+
openmc.lib.finalize()
680+
finally:
681+
os.chdir(self.orig_dir)
682+
self.tmp_dir.cleanup()
683+
684+
620685
class _DLLGlobal:
621686
"""Data descriptor that exposes global variables from libopenmc."""
622687
def __init__(self, ctype, name):

openmc/lib/nuclide.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,14 @@ def load_nuclide(name):
4040
name : str
4141
Name of the nuclide, e.g. 'U235'
4242
43+
Returns
44+
-------
45+
Nuclide
46+
The class:`Nuclide` that was just loaded.
47+
4348
"""
4449
_dll.openmc_load_nuclide(name.encode(), None, 0)
50+
return nuclides[name]
4551

4652

4753
class Nuclide(_FortranObject):

openmc/mesh.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -400,31 +400,30 @@ def material_volumes(
400400
"""
401401
import openmc.lib
402402

403-
with change_directory(tmpdir=True):
404-
# In order to get mesh into model, we temporarily replace the
405-
# tallies with a single mesh tally using the current mesh
406-
original_tallies = model.tallies
407-
new_tally = openmc.Tally()
408-
new_tally.filters = [openmc.MeshFilter(self)]
409-
new_tally.scores = ['flux']
410-
model.tallies = [new_tally]
411-
412-
# Export model to XML
413-
model.export_to_model_xml()
414-
415-
# Get material volume fractions
416-
kwargs.setdefault('output', True)
417-
if 'args' in kwargs:
418-
kwargs['args'] = ['-c'] + kwargs['args']
419-
kwargs.setdefault('args', ['-c'])
420-
openmc.lib.init(**kwargs)
403+
# In order to get mesh into model, we temporarily replace the
404+
# tallies with a single mesh tally using the current mesh
405+
original_tallies = model.tallies
406+
new_tally = openmc.Tally()
407+
new_tally.filters = [openmc.MeshFilter(self)]
408+
new_tally.scores = ['flux']
409+
model.tallies = [new_tally]
410+
411+
# Set default arguments
412+
kwargs.setdefault('output', True)
413+
if 'args' in kwargs:
414+
kwargs['args'] = ['-c'] + kwargs['args']
415+
kwargs.setdefault('args', ['-c'])
416+
417+
with openmc.lib.TemporarySession(model, **kwargs):
418+
# Get mesh from single tally
421419
mesh = openmc.lib.tallies[new_tally.id].filters[0].mesh
420+
421+
# Compute material volumes
422422
volumes = mesh.material_volumes(
423423
n_samples, max_materials, output=kwargs['output'])
424-
openmc.lib.finalize()
425424

426-
# Restore original tallies
427-
model.tallies = original_tallies
425+
# Restore original tallies
426+
model.tallies = original_tallies
428427

429428
return volumes
430429

openmc/weight_windows.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,10 +1041,6 @@ def export_to_hdf5(self, path: PathLike = 'weight_windows.h5', **init_kwargs):
10411041
# Get absolute path before moving to temporary directory
10421042
path = Path(path).resolve()
10431043

1044-
with change_directory(tmpdir=True):
1045-
# Write the model to an XML file
1046-
model.export_to_model_xml()
1047-
1048-
# Load the model with openmc.lib and then export it to an HDF5 file
1049-
with openmc.lib.run_in_memory(**init_kwargs):
1050-
openmc.lib.export_weight_windows(path)
1044+
# Load the model with openmc.lib and then export it to an HDF5 file
1045+
with openmc.lib.TemporarySession(model, **init_kwargs):
1046+
openmc.lib.export_weight_windows(path)

src/nuclide.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ extern "C" size_t nuclides_size()
11141114
extern "C" int openmc_load_nuclide(const char* name, const double* temps, int n)
11151115
{
11161116
if (data::nuclide_map.find(name) == data::nuclide_map.end() ||
1117-
data::nuclide_map.at(name) >= data::elements.size()) {
1117+
data::nuclide_map.at(name) >= data::nuclides.size()) {
11181118
LibraryKey key {Library::Type::neutron, name};
11191119
const auto& it = data::library_map.find(key);
11201120
if (it == data::library_map.end()) {
@@ -1215,7 +1215,6 @@ extern "C" int openmc_nuclide_collapse_rate(int index, int MT,
12151215
*xs = data::nuclides[index]->collapse_rate(
12161216
MT, temperature, {energy, energy + n + 1}, {flux, flux + n});
12171217
} catch (const std::out_of_range& e) {
1218-
fmt::print("Caught error\n");
12191218
set_errmsg(e.what());
12201219
return OPENMC_E_OUT_OF_BOUNDS;
12211220
}

tests/unit_tests/test_deplete_chain.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,7 @@ def test_capture_branch_infer_ground():
310310
# Create nuclide to be added into the chain
311311
xe136m = nuclide.Nuclide("Xe136_m1")
312312

313-
chain.nuclides.append(xe136m)
314-
chain.nuclide_dict[xe136m.name] = len(chain.nuclides) - 1
313+
chain.add_nuclide(xe136m)
315314

316315
chain.set_branch_ratios(infer_br, "(n,gamma)")
317316

@@ -327,8 +326,7 @@ def test_capture_branch_no_rxn():
327326

328327
u5m = nuclide.Nuclide("U235_m1")
329328

330-
chain.nuclides.append(u5m)
331-
chain.nuclide_dict[u5m.name] = len(chain.nuclides) - 1
329+
chain.add_nuclide(u5m)
332330

333331
with pytest.raises(AttributeError, match="U234"):
334332
chain.set_branch_ratios(u4br)

0 commit comments

Comments
 (0)