Skip to content

Commit 322b741

Browse files
GuyStenpaulromano
andauthored
Support Mixture distributions in combine_distributions (#3784)
Co-authored-by: Paul Romano <paul.k.romano@gmail.com>
1 parent 6050c78 commit 322b741

2 files changed

Lines changed: 41 additions & 7 deletions

File tree

openmc/stats/univariate.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,31 +2065,48 @@ def clip(self, tolerance: float = 1e-6, inplace: bool = False) -> Mixture:
20652065

20662066

20672067
def combine_distributions(
2068-
dists: Sequence[Discrete | Tabular],
2068+
dists: Sequence[Discrete | Tabular | Mixture],
20692069
probs: Sequence[float]
20702070
):
20712071
"""Combine distributions with specified probabilities
20722072
20732073
This function can be used to combine multiple instances of
2074-
:class:`~openmc.stats.Discrete` and `~openmc.stats.Tabular`. Multiple
2075-
discrete distributions are merged into a single distribution and the
2076-
remainder of the distributions are put into a :class:`~openmc.stats.Mixture`
2077-
distribution.
2074+
:class:`~openmc.stats.Discrete`, :class:`~openmc.stats.Tabular` and
2075+
:class:`~openmc.stats.Mixture` of them. Multiple discrete distributions are
2076+
merged into a single distribution and the remainder of the distributions are
2077+
put into a :class:`~openmc.stats.Mixture` distribution.
20782078
20792079
.. versionadded:: 0.13.1
20802080
20812081
Parameters
20822082
----------
2083-
dists : sequence of openmc.stats.Discrete or openmc.stats.Tabular
2083+
dists : sequence of openmc.stats.Discrete, openmc.stats.Tabular, or openmc.stats.Mixture
20842084
Distributions to combine
20852085
probs : sequence of float
20862086
Probability (or intensity) of each distribution
20872087
20882088
"""
2089+
new_probs = []
2090+
new_dists = []
20892091
for i, dist in enumerate(dists):
2090-
cv.check_type(f'dists[{i}]', dist, (Discrete, Tabular))
2092+
cv.check_type(f'dists[{i}]', dist, (Discrete, Tabular, Mixture))
20912093
cv.check_type(f'probs[{i}]', probs[i], Real)
20922094
cv.check_greater_than(f'probs[{i}]', probs[i], 0.0)
2095+
if isinstance(dist, Mixture):
2096+
if dist.bias is not None:
2097+
warn("A Mixture distribution with a bias specified was passed "
2098+
"to combine_distributions. The bias will be discarded "
2099+
"during flattening.")
2100+
for j, d in enumerate(dist.distribution):
2101+
cv.check_type(f'dists[{i}].distribution[{j}]', d, (Discrete, Tabular))
2102+
new_probs.append(probs[i]*dist.probability[j])
2103+
new_dists.append(d)
2104+
else:
2105+
new_probs.append(probs[i])
2106+
new_dists.append(dist)
2107+
2108+
probs = new_probs
2109+
dists = new_dists
20932110

20942111
# Get list of discrete/continuous distribution indices
20952112
discrete_index = [i for i, d in enumerate(dists) if isinstance(d, Discrete)]

tests/unit_tests/test_stats.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,23 @@ def test_combine_distributions():
835835
assert isinstance(mixed, openmc.stats.Mixture)
836836
assert len(mixed.distribution) == 2
837837
assert len(mixed.probability) == 2
838+
assert mixed == openmc.stats.combine_distributions([mixed], [1.0])
839+
840+
# Mixture combined with another distribution: probabilities should be
841+
# correctly scaled when the Mixture is flattened
842+
d_a = openmc.stats.delta_function(1.0)
843+
d_b = openmc.stats.delta_function(2.0)
844+
m = openmc.stats.Mixture([0.3, 0.7], [d_a, d_b])
845+
extra = openmc.stats.delta_function(3.0)
846+
result = openmc.stats.combine_distributions([m, extra], [0.5, 0.5])
847+
assert isinstance(result, openmc.stats.Discrete)
848+
assert result.x == pytest.approx([1.0, 2.0, 3.0])
849+
assert result.p == pytest.approx([0.5*0.3, 0.5*0.7, 0.5])
850+
851+
# Passing a Mixture with a bias should warn that the bias is dropped
852+
biased_m = openmc.stats.Mixture([0.5, 0.5], [d_a, d_b], bias=[0.8, 0.2])
853+
with pytest.warns(UserWarning, match='bias'):
854+
openmc.stats.combine_distributions([biased_m], [1.0])
838855

839856
# Single tabular returns a tabular distribution with scaled probabilities
840857
t_single = openmc.stats.Tabular([0.0, 1.0], [2.0, 0.0])

0 commit comments

Comments
 (0)