Skip to content

Commit 2691ff8

Browse files
GuyStenpaulromano
andauthored
Fix type hinting and simplify implementation of combine_distributions (#3445)
Co-authored-by: Paul Romano <paul.k.romano@gmail.com>
1 parent 5847b0d commit 2691ff8

2 files changed

Lines changed: 40 additions & 26 deletions

File tree

openmc/stats/univariate.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def from_xml_element(cls, elem: ET.Element):
397397
def merge(
398398
cls,
399399
dists: Sequence[Discrete],
400-
probs: Sequence[int]
400+
probs: Sequence[float]
401401
):
402402
"""Merge multiple discrete distributions into a single distribution
403403
@@ -1897,7 +1897,7 @@ def clip(self, tolerance: float = 1e-6, inplace: bool = False) -> Mixture:
18971897

18981898

18991899
def combine_distributions(
1900-
dists: Sequence[Univariate],
1900+
dists: Sequence[Discrete | Tabular],
19011901
probs: Sequence[float]
19021902
):
19031903
"""Combine distributions with specified probabilities
@@ -1912,41 +1912,40 @@ def combine_distributions(
19121912
19131913
Parameters
19141914
----------
1915-
dists : iterable of openmc.stats.Univariate
1915+
dists : sequence of openmc.stats.Discrete or openmc.stats.Tabular
19161916
Distributions to combine
1917-
probs : iterable of float
1917+
probs : sequence of float
19181918
Probability (or intensity) of each distribution
19191919
19201920
"""
1921-
# Get copy of distribution list so as not to modify the argument
1922-
dist_list = deepcopy(dists)
1921+
for i, dist in enumerate(dists):
1922+
cv.check_type(f'dists[{i}]', dist, (Discrete, Tabular))
1923+
cv.check_type(f'probs[{i}]', probs[i], Real)
1924+
cv.check_greater_than(f'probs[{i}]', probs[i], 0.0)
19231925

19241926
# Get list of discrete/continuous distribution indices
1925-
discrete_index = [i for i, d in enumerate(dist_list) if isinstance(d, Discrete)]
1926-
cont_index = [i for i, d in enumerate(dist_list) if isinstance(d, Tabular)]
1927+
discrete_index = [i for i, d in enumerate(dists) if isinstance(d, Discrete)]
1928+
cont_index = [i for i, d in enumerate(dists) if isinstance(d, Tabular)]
19271929

1928-
# Apply probabilites to continuous distributions
1929-
for i in cont_index:
1930-
dist = dist_list[i]
1931-
dist._p *= probs[i]
1930+
cont_dists = [dists[i] for i in cont_index]
1931+
cont_probs = [probs[i] for i in cont_index]
19321932

19331933
if discrete_index:
19341934
# Create combined discrete distribution
1935-
dist_discrete = [dist_list[i] for i in discrete_index]
1935+
dist_discrete = [dists[i] for i in discrete_index]
19361936
discrete_probs = [probs[i] for i in discrete_index]
19371937
combined_dist = Discrete.merge(dist_discrete, discrete_probs)
1938-
1939-
# Replace multiple discrete distributions with merged
1940-
for idx in reversed(discrete_index):
1941-
dist_list.pop(idx)
1942-
dist_list.append(combined_dist)
1943-
1944-
# Combine discrete and continuous if present
1945-
if len(dist_list) > 1:
1946-
probs = [1.0]*len(dist_list)
1947-
dist_list[:] = [Mixture(probs, dist_list.copy())]
1948-
1949-
return dist_list[0]
1938+
if cont_index:
1939+
return Mixture(cont_probs + [1.0], cont_dists + [combined_dist])
1940+
else:
1941+
return combined_dist
1942+
else:
1943+
if len(cont_dists) == 1:
1944+
dist = cont_dists[0]
1945+
return Tabular(dist.x, dist.p * cont_probs[0],
1946+
dist.interpolation, bias=dist.bias)
1947+
else:
1948+
return Mixture(cont_probs, cont_dists)
19501949

19511950

19521951
def check_bias_support(parent: Univariate, bias: Univariate | None):

tests/unit_tests/test_stats.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,12 +633,27 @@ def test_combine_distributions():
633633
assert len(mixed.distribution) == 2
634634
assert len(mixed.probability) == 2
635635

636+
# Single tabular returns a tabular distribution with scaled probabilities
637+
t_single = openmc.stats.Tabular([0.0, 1.0], [2.0, 0.0])
638+
scaled = openmc.stats.combine_distributions([t_single], [0.25])
639+
assert isinstance(scaled, openmc.stats.Tabular)
640+
assert scaled.p == pytest.approx([0.5, 0.0])
641+
642+
# Mixture with biased tabular should preserve unbiased mean via weights
643+
bias = openmc.stats.Tabular([0.0, 1.0], [2.0, 0.0])
644+
t_biased = openmc.stats.Tabular([0.0, 1.0], [1.0, 1.0], bias=bias)
645+
d1 = openmc.stats.delta_function(0.0)
646+
mixed = openmc.stats.combine_distributions([t_biased, d1], [0.5, 0.5])
647+
assert isinstance(mixed, openmc.stats.Mixture)
648+
samples, weights = mixed.sample(10_000)
649+
assert_sample_mean(samples*weights, 0.25)
650+
636651
# Combine 1 discrete and 2 tabular -- the tabular distributions should
637652
# combine to produce a uniform distribution with mean 0.5. The combined
638653
# distribution should have a mean of 0.25.
639654
t1 = openmc.stats.Tabular([0., 1.], [2.0, 0.0])
640655
t2 = openmc.stats.Tabular([0., 1.], [0.0, 2.0])
641-
d1 = openmc.stats.Discrete([0.0], [1.0])
656+
d1 = openmc.stats.delta_function(0.0)
642657
combined = openmc.stats.combine_distributions([t1, t2, d1], [0.25, 0.25, 0.5])
643658
assert combined.integral() == pytest.approx(1.0)
644659

0 commit comments

Comments
 (0)