Skip to content

Commit d3adca2

Browse files
Merge branch 'main' into metadata3
2 parents 7c165c7 + 6dc86f5 commit d3adca2

File tree

5 files changed

+51
-29
lines changed

5 files changed

+51
-29
lines changed

src/diffpy/labpdfproc/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def compute_cve(diffraction_data, mud, wavelength):
197197
"""
198198

199199
mu_sample_invmm = mud / 2
200-
abs_correction = Gridded_circle(n_points_on_diameter=N_POINTS_ON_DIAMETER, mu=mu_sample_invmm)
200+
abs_correction = Gridded_circle(mu=mu_sample_invmm)
201201
distances, muls = [], []
202202
for angle in TTH_GRID:
203203
abs_correction.set_distances_at_angle(angle)

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def main():
117117
args = load_user_info(args)
118118
args = set_input_lists(args)
119119
args.output_directory = set_output_directory(args)
120-
args.wavelength = set_wavelength(args)
120+
args = set_wavelength(args)
121121
args = load_user_metadata(args)
122122

123123
for filepath in args.input_paths:

src/diffpy/labpdfproc/tests/test_functions.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from diffpy.labpdfproc.functions import Gridded_circle, compute_cve
4+
from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, compute_cve
55
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
66

77
params1 = [
@@ -56,27 +56,47 @@ def test_set_muls_at_angle(inputs, expected):
5656
assert actual_muls_sorted == pytest.approx(expected_muls_sorted, rel=1e-4, abs=1e-6)
5757

5858

59-
def test_compute_cve(mocker):
60-
mocker.patch("diffpy.labpdfproc.functions.N_POINTS_ON_DIAMETER", 4)
61-
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", np.array([45, 60, 90]))
62-
input_pattern = Diffraction_object(wavelength=1.54)
63-
input_pattern.insert_scattering_quantity(
64-
np.array([45, 60, 90]),
65-
np.array([2.2, 3, 4]),
59+
def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"):
60+
test_do = Diffraction_object(wavelength=1.54)
61+
test_do.insert_scattering_quantity(
62+
xarray,
63+
yarray,
6664
"tth",
67-
scat_quantity="x-ray",
68-
name="test",
65+
scat_quantity=scat_quantity,
66+
name=name,
6967
metadata={"thing1": 1, "thing2": "thing2"},
7068
)
69+
return test_do
70+
71+
72+
def test_compute_cve(mocker):
73+
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
74+
expected_cve = np.array([0.5, 0.5, 0.5])
75+
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
76+
mocker.patch("numpy.interp", return_value=expected_cve)
77+
input_pattern = _instantiate_test_do(xarray, yarray)
7178
actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54)
72-
expected_abdo = Diffraction_object()
73-
expected_abdo.insert_scattering_quantity(
74-
np.array([45, 60, 90]),
75-
np.array([2.54253, 2.52852, 2.49717]),
76-
"tth",
77-
metadata={"thing1": 1, "thing2": "thing2"},
79+
expected_abdo = _instantiate_test_do(
80+
xarray,
81+
expected_cve,
7882
name="absorption correction, cve, for test",
79-
wavelength=1.54,
8083
scat_quantity="cve",
8184
)
8285
assert actual_abdo == expected_abdo
86+
87+
88+
def test_apply_corr(mocker):
89+
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
90+
expected_cve = np.array([0.5, 0.5, 0.5])
91+
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
92+
mocker.patch("numpy.interp", return_value=expected_cve)
93+
input_pattern = _instantiate_test_do(xarray, yarray)
94+
absorption_correction = _instantiate_test_do(
95+
xarray,
96+
expected_cve,
97+
name="absorption correction, cve, for test",
98+
scat_quantity="cve",
99+
)
100+
actual_corr = apply_corr(input_pattern, absorption_correction)
101+
expected_corr = _instantiate_test_do(xarray, np.array([1, 1, 1]))
102+
assert actual_corr == expected_corr

src/diffpy/labpdfproc/tests/test_tools.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,21 @@ def test_set_output_directory_bad(user_filesystem):
146146

147147

148148
params2 = [
149-
([], [0.71]),
150-
(["--anode-type", "Ag"], [0.59]),
151-
(["--wavelength", "0.25"], [0.25]),
152-
(["--wavelength", "0.25", "--anode-type", "Ag"], [0.25]),
149+
([], [0.71, "Mo"]),
150+
(["--anode-type", "Ag"], [0.59, "Ag"]),
151+
(["--wavelength", "0.25"], [0.25, None]),
152+
(["--wavelength", "0.25", "--anode-type", "Ag"], [0.25, None]),
153153
]
154154

155155

156156
@pytest.mark.parametrize("inputs, expected", params2)
157157
def test_set_wavelength(inputs, expected):
158-
expected_wavelength = expected[0]
158+
expected_wavelength, expected_anode_type = expected[0], expected[1]
159159
cli_inputs = ["2.5", "data.xy"] + inputs
160160
actual_args = get_args(cli_inputs)
161-
actual_args.wavelength = set_wavelength(actual_args)
161+
actual_args = set_wavelength(actual_args)
162162
assert actual_args.wavelength == expected_wavelength
163+
assert getattr(actual_args, "anode_type", None) == expected_anode_type
163164

164165

165166
params3 = [
@@ -183,7 +184,7 @@ def test_set_wavelength_bad(inputs, msg):
183184
cli_inputs = ["2.5", "data.xy"] + inputs
184185
actual_args = get_args(cli_inputs)
185186
with pytest.raises(ValueError, match=re.escape(msg[0])):
186-
actual_args.wavelength = set_wavelength(actual_args)
187+
actual_args = set_wavelength(actual_args)
187188

188189

189190
params5 = [

src/diffpy/labpdfproc/tools.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,12 @@ def set_wavelength(args):
127127
)
128128

129129
if args.wavelength:
130-
return args.wavelength
130+
delattr(args, "anode_type")
131131
elif args.anode_type:
132-
return WAVELENGTHS[args.anode_type]
132+
args.wavelength = WAVELENGTHS[args.anode_type]
133133
else:
134-
return WAVELENGTHS["Mo"]
134+
args.wavelength = WAVELENGTHS["Mo"]
135+
return args
135136

136137

137138
def _load_key_value_pair(s):

0 commit comments

Comments
 (0)