Skip to content

Commit 60ba7b1

Browse files
add numba, x10 speedup
1 parent 5cb7071 commit 60ba7b1

File tree

3 files changed

+28
-16
lines changed

3 files changed

+28
-16
lines changed

jafit/ja.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@
66
"""
77

88
from __future__ import annotations
9+
from typing import Any, Callable
910
import itertools
1011
from logging import getLogger
1112
import dataclasses
1213
import numpy as np
1314
import numpy.typing as npt
1415

16+
try:
17+
from numba import jit
18+
except ImportError:
19+
20+
def jit(**kwargs: Any) -> Callable[[Callable], Callable]:
21+
return lambda f: f
22+
1523

1624
mu_0 = 1.2566370614359173e-6 # Vacuum permeability [henry/meter]
1725

@@ -125,6 +133,7 @@ def solve(
125133
126134
Returns sample points for the H, M, and B fields in the order of their appearance.
127135
"""
136+
c_r, M_s, a, k_p, alpha = coef.c_r, coef.M_s, coef.a, coef.k_p, coef.alpha
128137
hmb_fragments: list[list[tuple[float, float, float]]] = []
129138

130139
# noinspection PyPep8Naming
@@ -136,9 +145,9 @@ def sweep(H: float, M: float, sign: int) -> list[tuple[float, float, float]]:
136145
# Integrate using Heun's method (instead of Euler's) for better stability.
137146
dH = sign * dH_abs
138147
try:
139-
dM_dH_1 = _dM_dH(coef, H=H, M=M, direction=sign)
148+
dM_dH_1 = _dM_dH(c_r=c_r, M_s=M_s, a=a, k_p=k_p, alpha=alpha, H=H, M=M, direction=sign)
140149
M_1 = M + dM_dH_1 * dH
141-
dM_dH_2 = _dM_dH(coef, H=H + dH, M=M_1, direction=sign)
150+
dM_dH_2 = _dM_dH(c_r=c_r, M_s=M_s, a=a, k_p=k_p, alpha=alpha, H=H + dH, M=M_1, direction=sign)
142151
M_2 = M + 0.5 * (dM_dH_1 + dM_dH_2) * dH
143152
except (FloatingPointError, ZeroDivisionError) as ex:
144153
if idx < 10:
@@ -193,40 +202,42 @@ def sweep(H: float, M: float, sign: int) -> list[tuple[float, float, float]]:
193202

194203

195204
# noinspection PyPep8Naming
196-
def _dM_dH(coef: Coef, *, H: float, M: float, direction: int) -> float:
205+
@jit(nogil=True, nopython=True)
206+
def _dM_dH(*, c_r: float, M_s: float, a: float, k_p: float, alpha: float, H: float, M: float, direction: int) -> float:
207+
# noinspection PyTypeChecker
197208
"""
198209
Evaluates the magnetic susceptibility derivative at the given point of the M(H) curve.
199210
The result is sensitive to the sign of the H change; the direction is defined as sign(dH).
200211
This implements the model described in "Jiles–Atherton Magnetic Hysteresis Parameters Identification", Pop et al.
201212
202-
>>> fun = lambda H, M, d: _dM_dH(COEF_COMSOL_JA_MATERIAL, H=H, M=M, direction=d)
213+
>>> fun = lambda H, M, d: _dM_dH(*dataclasses.asdict(COEF_COMSOL_JA_MATERIAL), H=H, M=M, direction=d)
203214
>>> assert np.isclose(fun(0, 0, +1), fun(0, 0, -1))
204215
>>> assert np.isclose(fun(+1, 0, +1), fun(-1, 0, -1))
205216
>>> assert np.isclose(fun(-1, 0, +1), fun(+1, 0, -1))
206217
>>> assert np.isclose(fun(-1, 0.8e6, +1), fun(+1, -0.8e6, -1))
207218
>>> assert np.isclose(fun(+1, 0.8e6, +1), fun(-1, -0.8e6, -1))
208219
"""
209-
if direction not in (-1, +1):
210-
raise ValueError(f"Invalid direction: {direction}")
211-
212-
H_e = H + coef.alpha * M
213-
M_an = coef.M_s * _langevin(H_e / coef.a)
214-
dM_an_dH_e = coef.M_s / coef.a * _dL_dx(H_e / coef.a)
215-
dM_irr_dH = (M_an - M) / (coef.k_p * direction * (1 - coef.c_r) - coef.alpha * (M_an - M))
216-
return (coef.c_r * dM_an_dH_e + (1 - coef.c_r) * dM_irr_dH) / (1 - coef.alpha * coef.c_r)
220+
assert direction in (-1, +1)
221+
H_e = H + alpha * M
222+
M_an = M_s * _langevin(H_e / a)
223+
dM_an_dH_e = M_s / a * _dL_dx(H_e / a)
224+
dM_irr_dH = (M_an - M) / (k_p * direction * (1 - c_r) - alpha * (M_an - M))
225+
return (c_r * dM_an_dH_e + (1 - c_r) * dM_irr_dH) / (1 - alpha * c_r) # type: ignore
217226

218227

228+
@jit(nogil=True, nopython=True)
219229
def _langevin(x: float) -> float:
220230
"""
221231
L(x) = coth(x) - 1/x
222232
For tensors, the function is applied element-wise.
223233
"""
224234
if np.abs(x) < _EPSILON: # For very small |x|, use the series expansion ~ x/3
225235
return x / 3.0
226-
return float(1.0 / np.tanh(x) - 1.0 / x)
236+
return 1.0 / np.tanh(x) - 1.0 / x # type: ignore
227237

228238

229239
# noinspection PyPep8Naming
240+
@jit(nogil=True, nopython=True)
230241
def _dL_dx(x: float) -> float:
231242
"""
232243
Derivative of Langevin L(x) = coth(x) - 1/x.
@@ -236,7 +247,7 @@ def _dL_dx(x: float) -> float:
236247
return 1.0 / 3.0
237248
# exact expression: -csch^2(x) + 1/x^2
238249
# csch^2(x) = 1 / sinh^2(x)
239-
return float(-1.0 / (np.sinh(x) ** 2) + 1.0 / (x**2))
250+
return -1.0 / (np.sinh(x) ** 2) + 1.0 / (x**2) # type: ignore
240251

241252

242253
_logger = getLogger(__name__)

jafit/jafit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def visualize(
179179
) -> None:
180180
import matplotlib.pyplot as plt
181181

182-
fig, (ax_m, ax_b) = plt.subplots(2, 1, figsize=(12, 10), sharex="all")
182+
fig, (ax_m, ax_b) = plt.subplots(2, 1, figsize=(12, 10), sharex="all") # type: ignore
183183
try:
184184
# Plot the curves predicted by the JA model
185185
for i, fragment in enumerate(sol.H_M_B_segments, start=1):
@@ -261,6 +261,7 @@ def main() -> None:
261261
logging.getLogger("matplotlib").setLevel(logging.ERROR)
262262
logging.getLogger("numpy").setLevel(logging.WARNING)
263263
logging.getLogger("scipy").setLevel(logging.WARNING)
264+
logging.getLogger("numba").setLevel(logging.WARNING)
264265

265266
np.seterr(all="raise")
266267

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ show_error_context = true
2828
mypy_path = []
2929

3030
[[tool.mypy.overrides]]
31-
module = ["scipy", "scipy.*", "matplotlib"]
31+
module = ["scipy", "scipy.*", "matplotlib", "numba"]
3232
ignore_missing_imports = true
3333

3434
# -------------------------------------------------- BLACK --------------------------------------------------

0 commit comments

Comments
 (0)