Skip to content

Commit 1302d9f

Browse files
committed
Fixes complex power for zero
1 parent e69f314 commit 1302d9f

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "formula-dispersion"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
authors = ["Florian Dobener <[email protected]>"]
55
edition = "2018"
66

src/ast.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,26 @@ pub enum EvaluateResult {
7373
Number(Complex64),
7474
}
7575

76+
trait ComplexPower {
77+
fn pc(self, exp: Complex64) -> Complex64;
78+
}
79+
80+
impl ComplexPower for Complex64 {
81+
fn pc(self, exp: Complex64) -> Complex64 {
82+
if self.re == 0. && self.im == 0. {
83+
return Complex64::from(0.);
84+
}
85+
self.powc(exp)
86+
}
87+
}
88+
7689
impl EvaluateResult {
7790
fn power(self, other: EvaluateResult) -> EvaluateResult {
7891
use EvaluateResult::*;
7992
match (self, other) {
80-
(Number(b), Number(exp)) => EvaluateResult::Number(b.powc(exp)),
81-
(Number(b), Array(exp)) => EvaluateResult::Array(exp.map(|x| b.powc(*x))),
82-
(Array(b), Number(exp)) => EvaluateResult::Array(b.map(|x| x.powc(exp))),
93+
(Number(b), Number(exp)) => EvaluateResult::Number(b.pc(exp)),
94+
(Number(b), Array(exp)) => EvaluateResult::Array(exp.map(|x| b.pc(*x))),
95+
(Array(b), Number(exp)) => EvaluateResult::Array(b.map(|x| x.pc(exp))),
8396
(Array(b), Array(exp)) => EvaluateResult::Array(
8497
Zip::from(&b)
8598
.and(&exp)

tests/test_array.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_powc():
2929
"""Using power is working properly"""
3030

3131
parsed = parse("n = lbda ** 3", "lbda", np.array([1.0, 2.0, 3.0]), {}, {})
32-
assert_array_almost_equal(parsed, np.array([1.0**2, 8.0**2, 27.0**2]))
32+
assert_array_almost_equal(parsed, np.array([1.0, 8.0, 27.0]))
3333

3434

3535
def test_sum():
@@ -80,3 +80,20 @@ def test_fails_on_wrong_token():
8080

8181
with pytest.raises(TypeError):
8282
parse("eps = 3 * 3 * lba", "lbda", np.array([1.0, 2.0, 3.0]), {}, {})
83+
84+
85+
def test_sellmeier():
86+
formula = "n = sqrt(eps_inf + sum[A * lambda ** e1 / (lambda ** 2 - B**e2)])"
87+
# formula = "n = B ** e2"
88+
rep_params = {
89+
"A": np.array([1.0, 0.00448263]),
90+
"B": np.array([0.0, 1.108205]),
91+
"e1": np.array([0.0, 0.0]),
92+
"e2": np.array([1.0, 2.0]),
93+
}
94+
single_params = {"eps_inf": 11.67316}
95+
lbda = np.linspace(6250, 23255.814, 100)
96+
97+
assert_array_almost_equal(
98+
lbda, parse(formula, "lambda", lbda, single_params, rep_params)
99+
)

0 commit comments

Comments
 (0)