1111from collections import defaultdict , namedtuple
1212from collections .abc import Mapping , Iterable
1313from numbers import Real , Integral
14+ from pathlib import Path
1415from warnings import warn
1516from typing import List
1617
@@ -278,7 +279,6 @@ def __len__(self):
278279 """Number of nuclides in chain."""
279280 return len (self .nuclides )
280281
281-
282282 @property
283283 def stable_nuclides (self ) -> List [Nuclide ]:
284284 """List of stable nuclides available in the chain"""
@@ -298,6 +298,7 @@ def add_nuclide(self, nuclide: Nuclide):
298298 Nuclide to add
299299
300300 """
301+ _invalidate_chain_cache (self )
301302 self .nuclide_dict [nuclide .name ] = len (self .nuclides )
302303 self .nuclides .append (nuclide )
303304
@@ -463,7 +464,6 @@ def from_endf(cls, decay_files, fpy_files, neutron_files,
463464 nuclide .add_reaction ('fission' , None , q_value , 1.0 )
464465 fissionable = True
465466
466-
467467 if fissionable :
468468 if parent in fpy_data :
469469 fpy = fpy_data [parent ]
@@ -558,6 +558,9 @@ def from_xml(cls, filename, fission_q=None):
558558 nuc = Nuclide .from_xml (nuclide_elem , root , this_q )
559559 chain .add_nuclide (nuc )
560560
561+ # Store path of XML file (used for handling cache invalidation)
562+ chain ._xml_path = str (Path (filename ).resolve ())
563+
561564 return chain
562565
563566 def export_to_xml (self , filename ):
@@ -888,7 +891,7 @@ def set_branch_ratios(self, branch_ratios, reaction="(n,gamma)",
888891 --------
889892 :meth:`get_branch_ratios`
890893 """
891-
894+ _invalidate_chain_cache ( self )
892895 # Store some useful information through the validation stage
893896
894897 sums = {}
@@ -1027,6 +1030,7 @@ def fission_yields(self):
10271030
10281031 @fission_yields .setter
10291032 def fission_yields (self , yields ):
1033+ _invalidate_chain_cache (self )
10301034 if yields is not None :
10311035 if isinstance (yields , Mapping ):
10321036 yields = [yields ]
@@ -1249,6 +1253,10 @@ def _follow(self, isotopes, level):
12491253 return found
12501254
12511255
1256+ # A global cache for Chain objects
1257+ _CHAIN_CACHE = {}
1258+
1259+
12521260def _get_chain (
12531261 chain_file : PathLike | Chain | None = None ,
12541262 fission_q : dict | None = None
@@ -1269,16 +1277,39 @@ def _get_chain(
12691277 Chain
12701278 Depletion chain instance.
12711279 """
1280+ # If chain_file is already a Chain, return it directly
12721281 if isinstance (chain_file , Chain ):
12731282 return chain_file
1274- elif isinstance ( chain_file , PathLike | None ):
1275- if chain_file is None :
1276- chain_file = openmc . config . get ( 'chain_file' )
1277- if ' chain_file' not in openmc .config :
1278- raise DataError (
1279- "No depletion chain specified and could not find depletion "
1280- "chain in openmc.config['chain_file'] "
1281- )
1282- return Chain . from_xml ( chain_file , fission_q )
1283- else :
1283+
1284+ # Resolve chain_file based on config if None
1285+ if chain_file is None :
1286+ chain_file = openmc .config . get ( 'chain_file' )
1287+ if 'chain_file' not in openmc . config :
1288+ raise DataError (
1289+ "No depletion chain specified and could not find depletion "
1290+ "chain in openmc.config['chain_file']"
1291+ )
1292+ elif not isinstance ( chain_file , PathLike ) :
12841293 raise TypeError ("chain_file must be path-like, a Chain, or None" )
1294+
1295+ # Determine the key for the cache, which consists of the absolute path, the
1296+ # file modification time, the file size, and the fission Q values.
1297+ chain_path = Path (chain_file ).resolve ()
1298+ stat_result = chain_path .stat ()
1299+ fq_tuple = tuple (sorted (fission_q .items ())) if fission_q else ()
1300+ key = (chain_path , stat_result .st_mtime , stat_result .st_size , fq_tuple )
1301+
1302+ # Check the global cache. If not cached, load the chain from XML and store
1303+ global _CHAIN_CACHE
1304+ if key not in _CHAIN_CACHE :
1305+ _CHAIN_CACHE [key ] = Chain .from_xml (chain_path , fission_q )
1306+ return _CHAIN_CACHE [key ]
1307+
1308+
1309+ def _invalidate_chain_cache (chain ):
1310+ """Invalidate the cache for a specific Chain (when it is modifed)."""
1311+ if hasattr (chain , '_xml_path' ):
1312+ # Remove all entries with the same path as self._xml_path
1313+ for key in list (_CHAIN_CACHE .keys ()):
1314+ if str (key [0 ]) == chain ._xml_path :
1315+ del _CHAIN_CACHE [key ]
0 commit comments