diff --git a/pyproject.toml b/pyproject.toml index c197980..f10793c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ select = [ "ISC", # flake8-implicit-str-concat "N", # pep8-naming "PL", # pylint + "PT", # flake8-pytest-style "Q", # flake8-quotes "RUF", # ruff "S", # flake8-bandit diff --git a/pyshocks/weno.py b/pyshocks/weno.py index 31eb4ec..48ed3ee 100644 --- a/pyshocks/weno.py +++ b/pyshocks/weno.py @@ -142,9 +142,7 @@ def weno_smoothness( ) -def weno_interp( - s: Stencil, u: Array, *, mode: ConvolutionType | None = None -) -> Array: +def weno_interp(s: Stencil, u: Array, *, mode: ConvolutionType | None = None) -> Array: r"""Interpolate the variable *u* at the cell faces for WENO-JS. The interpolation has the form @@ -352,7 +350,8 @@ def ss_weno_242_mask(sb: BoundaryStencil, u: Array) -> Array: if sb.bc == BoundaryType.Periodic: mask = jnp.ones((3, u.size), dtype=u.dtype) else: - assert sb.sl is not None and sb.sr is not None + assert sb.sl is not None + assert sb.sr is not None mask = jnp.ones((5, u.size + 1), dtype=u.dtype) mask = mask.at[0, :].set(0) mask = mask.at[-1, :].set(0) @@ -481,10 +480,12 @@ def ss_weno_242_interp(sb: BoundaryStencil, u: Array) -> Array: assert isinstance(sb.bc, BoundaryType) if sb.bc == BoundaryType.Periodic: - assert sb.sl is None and sb.sr is None + assert sb.sl is None + assert sb.sr is None uhat = weno_interp(sb.si, u, mode=ConvolutionType.Wrap) else: - assert sb.sl is not None and sb.sr is not None + assert sb.sl is not None + assert sb.sr is not None uhat = jnp.empty((5, u.size + 1), dtype=u.dtype) uhat = uhat.at[1:-1, 1:].set(weno_interp(sb.si, u)) diff --git a/requirements-dev.txt b/requirements-dev.txt index 9e5772e..e27a5e0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,9 +24,9 @@ fonttools==4.39.4 # via matplotlib iniconfig==2.0.0 # via pytest -jax==0.4.10 +jax==0.4.11 # via pyshocks (pyproject.toml) -jaxlib==0.4.10 +jaxlib==0.4.11 # via pyshocks (pyproject.toml) kiwisolver==1.4.4 # via matplotlib @@ -38,7 +38,7 @@ matplotlib==3.7.1 # scienceplots mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.1.0 +ml-dtypes==0.2.0 # via # jax # jaxlib @@ -89,15 +89,15 @@ python-dateutil==2.8.2 # via matplotlib pyweno @ git+https://github.com/memmett/PyWENO.git # via pyshocks (pyproject.toml) -regex==2023.5.5 +regex==2023.6.3 # via sphinx-lint restructuredtext-lint==1.4.0 # via doc8 -rich==13.3.5 +rich==13.4.1 # via pyshocks (pyproject.toml) -ruff==0.0.270 +ruff==0.0.272 # via pyshocks (pyproject.toml) -scienceplots==2.0.1 +scienceplots==2.1.0 # via pyshocks (pyproject.toml) scipy==1.10.1 # via @@ -113,7 +113,7 @@ sympy==1.12 # via pyshocks (pyproject.toml) types-dataclasses==0.6.6 # via pyshocks (pyproject.toml) -typing-extensions==4.6.2 +typing-extensions==4.6.3 # via mypy wheel==0.40.0 # via pip-tools diff --git a/requirements.txt b/requirements.txt index 49b472c..e66e890 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,15 +4,15 @@ # # pip-compile --output-file=requirements.txt --resolver=backtracking pyproject.toml # -jax==0.4.10 +jax==0.4.11 # via pyshocks (pyproject.toml) -jaxlib==0.4.10 +jaxlib==0.4.11 # via pyshocks (pyproject.toml) markdown-it-py==2.2.0 # via rich mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.1.0 +ml-dtypes==0.2.0 # via # jax # jaxlib @@ -28,7 +28,7 @@ opt-einsum==3.3.0 # via jax pygments==2.15.1 # via rich -rich==13.3.5 +rich==13.4.1 # via pyshocks (pyproject.toml) scipy==1.10.1 # via diff --git a/tests/test_weno.py b/tests/test_weno.py index d9a0a81..8dfef08 100644 --- a/tests/test_weno.py +++ b/tests/test_weno.py @@ -45,7 +45,8 @@ def test_weno_smoothness_indicator_vectorization( a = rec.s.a b = rec.s.b c = rec.s.c - assert a is not None and b is not None + assert a is not None + assert b is not None nghosts = b.shape[-1] // 2 nstencils = b.shape[0] @@ -227,14 +228,16 @@ def test_weno_vs_pyweno( error_l = rnorm(grid, sl, betal) error_r = rnorm(grid, sr, betar) logger.info("error smoothness: left %.5e right %.5e", error_l, error_r) - assert error_l < 1.0e-5 and error_r < 1.0e-8 + assert error_l < 1.0e-5 + assert error_r < 1.0e-8 ulhat, urhat = reconstruct(rec, grid, BoundaryType.Dirichlet, u, u, u) error_l = rnorm(grid, ul, ulhat) error_r = rnorm(grid, ur, urhat) logger.info("error reconstruct: left %.5e right %.5e", error_l, error_r) - assert error_l < 1.0e-12 and error_r < 1.0e-12 + assert error_l < 1.0e-12 + assert error_r < 1.0e-12 # }}}