Skip to content

Commit 2bd0666

Browse files
Parallelize sampling external sources and threadsafe rejection counters (#3830)
Co-authored-by: Paul Romano <paul.k.romano@gmail.com>
1 parent 0ab46df commit 2bd0666

6 files changed

Lines changed: 115 additions & 38 deletions

File tree

include/openmc/source.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#ifndef OPENMC_SOURCE_H
55
#define OPENMC_SOURCE_H
66

7+
#include <atomic>
78
#include <limits>
89
#include <unordered_set>
910

@@ -25,10 +26,18 @@ namespace openmc {
2526
// source_rejection_fraction
2627
constexpr int EXTSRC_REJECT_THRESHOLD {10000};
2728

29+
// Maximum number of source rejections allowed while sampling a single site
30+
constexpr int64_t MAX_SOURCE_REJECTIONS_PER_SAMPLE {1'000'000};
31+
2832
//==============================================================================
2933
// Global variables
3034
//==============================================================================
3135

36+
// Cumulative counters for source rejection diagnostics. These are atomic to
37+
// allow thread-safe concurrent sampling of external sources.
38+
extern std::atomic<int64_t> source_n_accept;
39+
extern std::atomic<int64_t> source_n_reject;
40+
3241
class Source;
3342

3443
namespace model {
@@ -265,6 +274,9 @@ SourceSite sample_external_source(uint64_t* seed);
265274

266275
void free_memory_source();
267276

277+
//! Reset cumulative source rejection counters
278+
void reset_source_rejection_counters();
279+
268280
} // namespace openmc
269281

270282
#endif // OPENMC_SOURCE_H

openmc/lib/core.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class _SourceSite(Structure):
3333
('parent_id', c_int64),
3434
('progeny_id', c_int64)]
3535

36-
3736
# Define input type for numpy arrays that will be passed into C++ functions
3837
# Must be an int or double array, with single dimension that is contiguous
3938
_array_1d_int = np.ctypeslib.ndpointer(dtype=np.int32, ndim=1,
@@ -494,8 +493,9 @@ def run_random_ray(output=True):
494493

495494
def sample_external_source(
496495
n_samples: int = 1000,
497-
prn_seed: int | None = None
498-
) -> openmc.ParticleList:
496+
prn_seed: int | None = None,
497+
as_array: bool = False
498+
) -> openmc.ParticleList | np.ndarray:
499499
"""Sample external source and return source particles.
500500
501501
.. versionadded:: 0.13.1
@@ -507,30 +507,49 @@ def sample_external_source(
507507
prn_seed : int
508508
Pseudorandom number generator (PRNG) seed; if None, one will be
509509
generated randomly.
510+
as_array : bool
511+
If True, return a numpy structured array instead of a
512+
:class:`~openmc.ParticleList`. The array has fields ``'r'`` (float64,
513+
shape 3), ``'u'`` (float64, shape 3), ``'E'`` (float64), ``'time'``
514+
(float64), ``'wgt'`` (float64), ``'delayed_group'`` (int32),
515+
``'surf_id'`` (int32), and ``'particle'`` (int32). This avoids the
516+
overhead of constructing individual :class:`~openmc.SourceParticle`
517+
objects and is substantially faster for large sample counts.
510518
511519
Returns
512520
-------
513-
openmc.ParticleList
514-
List of sampled source particles
521+
openmc.ParticleList or numpy.ndarray
522+
List of sampled source particles, or a structured array when
523+
*as_array* is True.
515524
516525
"""
517526
if n_samples <= 0:
518527
raise ValueError("Number of samples must be positive")
519528
if prn_seed is None:
520529
prn_seed = getrandbits(63)
521530

522-
# Call into C API to sample source
523-
sites_array = (_SourceSite * n_samples)()
524-
_dll.openmc_sample_external_source(c_size_t(n_samples), c_uint64(prn_seed), sites_array)
525-
526-
# Convert to list of SourceParticle and return
527-
return openmc.ParticleList([openmc.SourceParticle(
528-
r=site.r, u=site.u, E=site.E, time=site.time, wgt=site.wgt,
529-
delayed_group=site.delayed_group, surf_id=site.surf_id,
530-
particle=openmc.ParticleType(site.particle)
531+
# Pre-allocate output array and sample all particles in a single C call
532+
result = np.empty(n_samples, dtype=_SourceSite)
533+
sites_array = (_SourceSite * n_samples).from_buffer(result)
534+
_dll.openmc_sample_external_source(
535+
c_size_t(n_samples),
536+
c_uint64(prn_seed),
537+
sites_array,
538+
)
539+
540+
if as_array:
541+
return result
542+
543+
particles = [
544+
openmc.SourceParticle(
545+
r=site.r, u=site.u, E=site.E, time=site.time,
546+
wgt=site.wgt, delayed_group=site.delayed_group,
547+
surf_id=site.surf_id,
548+
particle=openmc.ParticleType(site.particle),
531549
)
532550
for site in sites_array
533-
])
551+
]
552+
return openmc.ParticleList(particles)
534553

535554

536555
def simulation_init():

openmc/model/model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,9 @@ def sample_external_source(
12941294
self,
12951295
n_samples: int = 1000,
12961296
prn_seed: int | None = None,
1297+
as_array: bool = False,
12971298
**init_kwargs
1298-
) -> openmc.ParticleList:
1299+
) -> openmc.ParticleList | np.ndarray:
12991300
"""Sample external source and return source particles.
13001301
13011302
.. versionadded:: 0.15.1
@@ -1307,13 +1308,17 @@ def sample_external_source(
13071308
prn_seed : int
13081309
Pseudorandom number generator (PRNG) seed; if None, one will be
13091310
generated randomly.
1311+
as_array : bool
1312+
If True, return a numpy structured array instead of a
1313+
:class:`~openmc.ParticleList`.
13101314
**init_kwargs
13111315
Keyword arguments passed to :func:`openmc.lib.init`
13121316
13131317
Returns
13141318
-------
1315-
openmc.ParticleList
1316-
List of samples source particles
1319+
openmc.ParticleList or numpy.ndarray
1320+
List of sampled source particles, or a structured array when
1321+
*as_array* is True.
13171322
"""
13181323
import openmc.lib
13191324

@@ -1324,7 +1329,7 @@ def sample_external_source(
13241329

13251330
with openmc.lib.TemporarySession(self, **init_kwargs):
13261331
return openmc.lib.sample_external_source(
1327-
n_samples=n_samples, prn_seed=prn_seed
1332+
n_samples=n_samples, prn_seed=prn_seed, as_array=as_array
13281333
)
13291334

13301335
def apply_tally_results(self, statepoint: PathLike | openmc.StatePoint):
@@ -2588,7 +2593,7 @@ def convert_to_multigroup(
25882593
# This mode doesn't require
25892594
# valid transport settings like particles/batches
25902595
original_run_mode = self.settings.run_mode
2591-
self.settings.run_mode = 'volume'
2596+
self.settings.run_mode = 'volume'
25922597
self.init_lib(directory=tmpdir)
25932598
self.sync_dagmc_universes()
25942599
self.finalize_lib()

src/simulation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ int openmc_simulation_init()
122122
simulation::ssw_current_file = 1;
123123
simulation::k_generation.clear();
124124
simulation::entropy.clear();
125+
reset_source_rejection_counters();
125126
openmc_reset();
126127

127128
// If this is a restart run, load the state point data and binary source

src/source.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737

3838
namespace openmc {
3939

40+
std::atomic<int64_t> source_n_accept {0};
41+
std::atomic<int64_t> source_n_reject {0};
42+
4043
namespace {
4144

4245
void validate_particle_type(ParticleType type, const std::string& context)
@@ -191,9 +194,8 @@ void check_rejection_fraction(int64_t n_reject, int64_t n_accept)
191194
SourceSite Source::sample_with_constraints(uint64_t* seed) const
192195
{
193196
bool accepted = false;
194-
static int64_t n_reject = 0;
195-
static int64_t n_accept = 0;
196-
SourceSite site;
197+
int64_t n_local_reject = 0;
198+
SourceSite site {};
197199

198200
while (!accepted) {
199201
// Sample a source site without considering constraints yet
@@ -207,9 +209,13 @@ SourceSite Source::sample_with_constraints(uint64_t* seed) const
207209
satisfies_energy_constraints(site.E) &&
208210
satisfies_time_constraints(site.time);
209211
if (!accepted) {
210-
// Increment number of rejections and check against minimum fraction
211-
++n_reject;
212-
check_rejection_fraction(n_reject, n_accept);
212+
++n_local_reject;
213+
214+
// Check per-particle rejection limit
215+
if (n_local_reject >= MAX_SOURCE_REJECTIONS_PER_SAMPLE) {
216+
fatal_error("Exceeded maximum number of source rejections per "
217+
"sample. Please check your source definition.");
218+
}
213219

214220
// For the "kill" strategy, accept particle but set weight to 0 so that
215221
// it is terminated immediately
@@ -221,8 +227,13 @@ SourceSite Source::sample_with_constraints(uint64_t* seed) const
221227
}
222228
}
223229

224-
// Increment number of accepted samples
225-
++n_accept;
230+
// Flush local rejection count, update accept counter, and check overall
231+
// rejection fraction
232+
if (n_local_reject > 0) {
233+
source_n_reject += n_local_reject;
234+
}
235+
++source_n_accept;
236+
check_rejection_fraction(source_n_reject, source_n_accept);
226237

227238
return site;
228239
}
@@ -361,15 +372,14 @@ IndependentSource::IndependentSource(pugi::xml_node node) : Source(node)
361372

362373
SourceSite IndependentSource::sample(uint64_t* seed) const
363374
{
364-
SourceSite site;
375+
SourceSite site {};
365376
site.particle = particle_;
366377
double r_wgt = 1.0;
367378
double E_wgt = 1.0;
368379

369380
// Repeat sampling source location until a good site has been accepted
370381
bool accepted = false;
371-
static int64_t n_reject = 0;
372-
static int64_t n_accept = 0;
382+
int64_t n_local_reject = 0;
373383

374384
while (!accepted) {
375385

@@ -383,8 +393,11 @@ SourceSite IndependentSource::sample(uint64_t* seed) const
383393

384394
// Check for rejection
385395
if (!accepted) {
386-
++n_reject;
387-
check_rejection_fraction(n_reject, n_accept);
396+
++n_local_reject;
397+
if (n_local_reject >= MAX_SOURCE_REJECTIONS_PER_SAMPLE) {
398+
fatal_error("Exceeded maximum number of source rejections per "
399+
"sample. Please check your source definition.");
400+
}
388401
}
389402
}
390403

@@ -419,8 +432,11 @@ SourceSite IndependentSource::sample(uint64_t* seed) const
419432
(satisfies_energy_constraints(site.E)))
420433
break;
421434

422-
n_reject++;
423-
check_rejection_fraction(n_reject, n_accept);
435+
++n_local_reject;
436+
if (n_local_reject >= MAX_SOURCE_REJECTIONS_PER_SAMPLE) {
437+
fatal_error("Exceeded maximum number of source rejections per "
438+
"sample. Please check your source definition.");
439+
}
424440
}
425441

426442
// Sample particle creation time
@@ -430,8 +446,10 @@ SourceSite IndependentSource::sample(uint64_t* seed) const
430446
site.wgt *= (E_wgt * time_wgt);
431447
}
432448

433-
// Increment number of accepted samples
434-
++n_accept;
449+
// Flush local rejection count into global counter
450+
if (n_local_reject > 0) {
451+
source_n_reject += n_local_reject;
452+
}
435453

436454
return site;
437455
}
@@ -692,6 +710,13 @@ SourceSite sample_external_source(uint64_t* seed)
692710
void free_memory_source()
693711
{
694712
model::external_sources.clear();
713+
reset_source_rejection_counters();
714+
}
715+
716+
void reset_source_rejection_counters()
717+
{
718+
source_n_accept = 0;
719+
source_n_reject = 0;
695720
}
696721

697722
//==============================================================================
@@ -712,8 +737,15 @@ extern "C" int openmc_sample_external_source(
712737
}
713738

714739
auto sites_array = static_cast<SourceSite*>(sites);
740+
741+
// Derive independent per-particle seeds from the base seed so that
742+
// each iteration has its own RNG state for thread-safe parallel sampling.
743+
uint64_t base_seed = *seed;
744+
745+
#pragma omp parallel for schedule(static)
715746
for (size_t i = 0; i < n; ++i) {
716-
sites_array[i] = sample_external_source(seed);
747+
uint64_t particle_seed = init_seed(base_seed + i, STREAM_SOURCE);
748+
sites_array[i] = sample_external_source(&particle_seed);
717749
}
718750
return 0;
719751
}

tests/unit_tests/test_lib.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,14 @@ def test_sample_external_source(run_in_tmpdir, mpi_intracomm):
11141114
assert p1.time == p2.time
11151115
assert p1.wgt == p2.wgt
11161116

1117+
# as_array should return a numpy structured array with matching values
1118+
arr = openmc.lib.sample_external_source(10, prn_seed=3, as_array=True)
1119+
assert isinstance(arr, np.ndarray)
1120+
assert len(arr) == 10
1121+
for p, row in zip(particles, arr):
1122+
assert p.r == pytest.approx(row['r'])
1123+
assert p.E == pytest.approx(row['E'])
1124+
11171125
openmc.lib.finalize()
11181126

11191127
# Make sure sampling works in volume calculation mode

0 commit comments

Comments
 (0)