Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/usersguide/decay_sources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,5 @@ relevant tallies. This can be done with the aid of the
dose_tally = sp.get_tally(name='dose tally')

# Apply time correction factors
tally = d1s.apply_time_correction(dose_tally, factors, time_index)
tally = d1s.apply_time_correction(dose_tally, factors, time_indexes)

113 changes: 64 additions & 49 deletions openmc/deplete/d1s.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openmc.data import half_life
from .abc import _normalize_timesteps
from .chain import Chain, _get_chain
from ..checkvalue import PathLike
from ..checkvalue import PathLike, check_iterable_type


def get_radionuclides(model: openmc.Model, chain_file: PathLike | Chain | None = None) -> list[str]:
Expand Down Expand Up @@ -124,12 +124,12 @@ def time_correction_factors(
def apply_time_correction(
tally: openmc.Tally,
time_correction_factors: dict[str, np.ndarray],
index: int = -1,
indexes: Sequence[int] = [-1],
sum_nuclides: bool = True
) -> openmc.Tally:
) -> list[openmc.Tally]:
"""Apply time correction factors to a tally.

This function applies the time correction factors at the given index to a
This function applies the time correction factors at the given index(es) to a
tally that contains a :class:`~openmc.ParentNuclideFilter`. When
`sum_nuclides` is True, values over all parent nuclides will be summed,
leaving a single value for each filter combination.
Expand All @@ -140,19 +140,24 @@ def apply_time_correction(
Tally to apply the time correction factors to
time_correction_factors : dict
Time correction factors as returned by :func:`time_correction_factors`
index : int, optional
Index of the time of interest. If N timesteps are provided in
indexes : sequence of int, optional
Sequence of time indices of interest. If N timesteps are provided in
:func:`time_correction_factors`, there are N + 1 times to select from.
The default is -1 which corresponds to the final time.
The default is [-1] which corresponds to the final time. For a single
time index, pass [index].
sum_nuclides : bool
Whether to sum over the parent nuclides

Returns
-------
openmc.Tally
Derived tally with time correction factors applied
list of openmc.Tally
List of derived tallies with time correction factors applied, one for
each index in the input sequence.

"""

check_iterable_type('indexes', indexes, int)

# Make sure the tally contains a ParentNuclideFilter
for i_filter, filter in enumerate(tally.filters):
if isinstance(filter, openmc.ParentNuclideFilter):
Expand All @@ -162,49 +167,59 @@ def apply_time_correction(

# Get list of radionuclides based on tally filter
radionuclides = [str(x) for x in tally.filters[i_filter].bins]
tcf = np.array([time_correction_factors[x][index] for x in radionuclides])

# Create copy of tally
new_tally = deepcopy(tally)
n_radionuclides = len(radionuclides)

# Extract TCF values for all requested indices at once
# Shape: (n_indices, n_radionuclides)
tcf_matrix = np.array([[time_correction_factors[nuc][idx] for nuc in radionuclides]
for idx in indexes])

# Determine number of bins in other filters
_, n_nuclides, n_scores = tally.shape
n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]])
n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]])

# Reshape sum and sum_sq, apply TCF, and sum along that axis
_, n_nuclides, n_scores = new_tally.shape
n_radionuclides = len(radionuclides)
shape = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores)
tally_sum = new_tally.sum.reshape(shape)
tally_sum_sq = new_tally.sum_sq.reshape(shape)

# Apply TCF, broadcasting to the correct dimensions
tcf.shape = (1, -1, 1, 1, 1)
new_tally._sum = tally_sum * tcf
new_tally._sum_sq = tally_sum_sq * (tcf*tcf)
new_tally._mean = None
new_tally._std_dev = None

shape = (-1, n_nuclides, n_scores)

if sum_nuclides:
# Query the mean and standard deviation
mean = new_tally.mean
std_dev = new_tally.std_dev

# Sum over parent nuclides (note that when combining different bins for
# parent nuclide, we can't work directly on sum_sq)
new_tally._mean = mean.sum(axis=1).reshape(shape)
new_tally._std_dev = np.linalg.norm(std_dev, axis=1).reshape(shape)
new_tally._derived = True

# Remove ParentNuclideFilter
new_tally.filters.pop(i_filter)
else:
new_tally._sum.shape = shape
new_tally._sum_sq.shape = shape

return new_tally

# Original tally shape for reshaping
shape_5d = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores)
final_shape = (-1, n_nuclides, n_scores)

# Get original tally data once
tally_sum = tally.sum.reshape(shape_5d)
tally_sum_sq = tally.sum_sq.reshape(shape_5d)

# Process all indices efficiently
results = []
for i, idx in enumerate(indexes):

new_tally = deepcopy(tally)

# Get TCF for this specific index
tcf = tcf_matrix[i].reshape(1, -1, 1, 1, 1)

# Apply corrections
new_tally._sum = tally_sum * tcf
new_tally._sum_sq = tally_sum_sq * (tcf * tcf)
new_tally._mean = None
new_tally._std_dev = None

if sum_nuclides:
# Query the mean and standard deviation
mean = new_tally.mean
std_dev = new_tally.std_dev

# Sum over parent nuclides
new_tally._mean = mean.sum(axis=1).reshape(final_shape)
new_tally._std_dev = np.linalg.norm(std_dev, axis=1).reshape(final_shape)
new_tally._derived = True

# Remove ParentNuclideFilter
new_tally.filters.pop(i_filter)
else:
new_tally._sum.shape = final_shape
new_tally._sum_sq.shape = final_shape

results.append(new_tally)

return results


def prepare_tallies(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_d1s.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ def test_apply_time_correction(run_in_tmpdir):
flux = tally.mean.flatten()

# Apply TCF and make sure results are consistent
result = d1s.apply_time_correction(tally, factors, sum_nuclides=False)
result = d1s.apply_time_correction(tally, factors, sum_nuclides=False)[0]
tcf = np.array([factors[nuc][-1] for nuc in nuclides])
assert result.mean.flatten() == pytest.approx(tcf * flux)

# Make sure summed results match a manual sum
result_summed = d1s.apply_time_correction(tally, factors)
result_summed = d1s.apply_time_correction(tally, factors)[0]
assert result_summed.mean.flatten()[0] == pytest.approx(result.mean.sum())

# Make sure various tally methods work
Expand Down