Skip to content

Fix a bug in openmc.data.combine_distributions #3445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 87 additions & 32 deletions openmc/stats/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def from_xml_element(cls, elem: ET.Element):
def merge(
cls,
dists: Sequence[Discrete],
probs: Sequence[int]
probs: Sequence[float]
):
"""Merge multiple discrete distributions into a single distribution

Expand Down Expand Up @@ -830,6 +830,36 @@ def from_xml_element(cls, elem: ET.Element):
params = get_text(elem, 'parameters').split()
return cls(*map(float, params))

@classmethod
def merge(
cls,
dists: Sequence[Normal],
probs: Sequence[float]
):
"""Merge multiple normal distributions into a single distribution

Parameters
----------
dists : iterable of openmc.stats.Normal
Normal distributions to combine
probs : iterable of float
Probability of each distribution

Returns
-------
openmc.stats.Normal
Combined normal distribution

"""
if len(dists) != len(probs):
raise ValueError("Number of distributions and probabilities must match.")
means = np.array([d.mean_value for d in dists])
stds = np.array([d.std_dev for d in dists])
probs = np.asarray(probs)
combined_mean = np.dot(probs,means)
combined_std = np.sqrt(np.dot(probs**2,stds**2))
return cls(combined_mean, combined_std)


def muir(e0: float, m_rat: float, kt: float):
"""Generate a Muir energy spectrum
Expand Down Expand Up @@ -1324,6 +1354,59 @@ def from_xml_element(cls, elem: ET.Element):

return cls(probability, distribution)

@classmethod
def merge(
cls,
dists: Sequence[Univariate],
probs: Sequence[float]
):
"""Merge multiple distributions into a single distribution

Parameters
----------
dists : iterable of openmc.stats.Univariate
Univariate distributions to combine
probs : iterable of float
Probability of each distribution

Returns
-------
openmc.stats.Univariate
Combined distribution

"""
if len(dists) != len(probs):
raise ValueError("Number of distributions and probabilities must match.")

discrete = [[],[]]
normal = [[],[]]
others = [[],[]]
for p,d in zip(probs,dists):
if isinstance(d, Mixture):
iterator = zip(p * m.probability, m.distribution)
else:
iterator = [(p,d)]
for prob,dist in iterator:
if isinstance(dist,Discrete):
discrete[0].append(dist)
discrete[1].append(prob)
elif isinstance(dist,Normal):
normal[0].append(dist)
normal[1].append(prob)
else:
others[0].append(dist)
others[1].append(prob)
if discrete[0]:
others[0].append(Discrete.merge(*discrete))
others[1].append(1.0)
if normal[0]:
others[0].append(Normal.merge(*normal))
others[1].append(1.0)
if others[1]==[1.0]:
return others[0][0]
else:
return cls(others[1], others[0])

def integral(self):
"""Return integral of the distribution

Expand Down Expand Up @@ -1406,8 +1489,8 @@ def combine_distributions(
"""Combine distributions with specified probabilities

This function can be used to combine multiple instances of
:class:`~openmc.stats.Discrete` and `~openmc.stats.Tabular`. Multiple
discrete distributions are merged into a single distribution and the
:class:`~openmc.stats.Univariate`. Multiple
discrete or normal distributions are merged into a single distribution and the
remainder of the distributions are put into a :class:`~openmc.stats.Mixture`
distribution.

Expand All @@ -1421,32 +1504,4 @@ def combine_distributions(
Probability (or intensity) of each distribution

"""
# Get copy of distribution list so as not to modify the argument
dist_list = deepcopy(dists)

# Get list of discrete/continuous distribution indices
discrete_index = [i for i, d in enumerate(dist_list) if isinstance(d, Discrete)]
cont_index = [i for i, d in enumerate(dist_list) if isinstance(d, Tabular)]

# Apply probabilites to continuous distributions
for i in cont_index:
dist = dist_list[i]
dist._p *= probs[i]

if discrete_index:
# Create combined discrete distribution
dist_discrete = [dist_list[i] for i in discrete_index]
discrete_probs = [probs[i] for i in discrete_index]
combined_dist = Discrete.merge(dist_discrete, discrete_probs)

# Replace multiple discrete distributions with merged
for idx in reversed(discrete_index):
dist_list.pop(idx)
dist_list.append(combined_dist)

# Combine discrete and continuous if present
if len(dist_list) > 1:
probs = [1.0]*len(dist_list)
dist_list[:] = [Mixture(probs, dist_list.copy())]

return dist_list[0]
return Mixture.merge(dists, probs)
9 changes: 9 additions & 0 deletions tests/unit_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,15 @@ def test_normal():
assert_sample_mean(samples, mean)


def test_combine_normal():
norm1=openmc.stats.Normal(mean_value=10, std_dev=1)
norm2=openmc.stats.Normal(mean_value=1, std_dev=1)
combined=openmc.stats.combine_distributions([norm1,norm2],probs=[0.1,0.9])
assert isinstance(combined, openmc.stats.Normal)
assert combined.mean_value == pytest.approx(1.9)
assert combined.std_dev**2 == pytest.approx(0.82)


def test_muir():
mean = 10.0
mass = 5.0
Expand Down