Skip to content

Commit

Permalink
Prevents refractive index dispersions from being summed (#113)
Browse files Browse the repository at this point in the history
* Adds unsummable dispersion

* Replaces UnsummableDispersion with IndexDispersion

* Lets IndexTable inherit from IndexDispersion

* Raise error on adding of tabular dispersions

* Offload add to table dispersion
  • Loading branch information
domna authored Jan 30, 2023
1 parent b9e22a8 commit e8a0789
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 29 deletions.
85 changes: 76 additions & 9 deletions src/elli/dispersions/base_dispersion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Encoding: utf-8
"""Abstract base class and utility classes for pyElli dispersion"""
from abc import ABC, abstractmethod
from typing import Union
from typing import List, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -101,13 +101,32 @@ def add(self, *args, **kwargs) -> "Dispersion":

return self

def __add__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion":
def _check_valid_operand(self, other: Union[int, float, "Dispersion"]):
if not isinstance(other, (int, float, Dispersion)):
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)

def _is_non_std_dispersion(self, other: Union[int, float, "Dispersion"]) -> bool:
return isinstance(other, (IndexDispersion, dispersions.Table))

def __radd__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum":
"""Add up the dielectric function of multiple models"""
if isinstance(other, (int, float)):
return DispersionSum(self, dispersions.EpsilonInf(eps=other))
return self.__add__(other)

if not isinstance(other, Dispersion):
raise TypeError(f"Invalid type {type(other)} added to dispersion")
def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum":
"""Add up the dielectric function of multiple models"""
self._check_valid_operand(other)

if self._is_non_std_dispersion(other):
return other.__add__(self)

if isinstance(other, DispersionSum):
other.dispersions.append(self)
return other

if isinstance(other, (int, float)):
return DispersionSum(self, dispersions.EpsilonInf(other))

return DispersionSum(self, other)

Expand Down Expand Up @@ -195,6 +214,36 @@ def _dict_to_str(dic):
)


class IndexDispersion(Dispersion):
"""A dispersion based on a refractive index formulation."""

@abstractmethod
def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray:
"""Calculates the refractive index in a given wavelength window.
Args:
lbda (npt.ArrayLike): The wavelength window with unit nm.
Returns:
npt.NDArray: The refractive index for each wavelength point.
"""

def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum":
self._check_valid_operand(other)

if isinstance(other, IndexDispersion):
raise NotImplementedError(
"Adding of index based dispersions is not supported yet"
)

raise TypeError(
"Cannot add refractive index and dielectric function based dispersions."
)

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
return self.refractive_index(lbda) ** 2


class DispersionFactory:
"""A factory class for dispersion objects"""

Expand All @@ -220,12 +269,30 @@ def get_dispersion(identifier: str, *args, **kwargs) -> Dispersion:
class DispersionSum(Dispersion):
"""Represents a sum of two dispersions"""

single_params_template = {}
rep_params_template = {}
single_params_template: dict = {}
rep_params_template: dict = {}
dispersions: List[Dispersion]

def __init__(self, *disps: Dispersion) -> None:
super().__init__()
self.dispersions = disps
self.dispersions = list(disps)

def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum":
self._check_valid_operand(other)

if self._is_non_std_dispersion(other):
return other.__add__(self)

if isinstance(other, DispersionSum):
self.dispersions += other.dispersions
return self

if isinstance(other, (int, float)):
self.dispersions.append(dispersions.EpsilonInf(eps=other))
return self

self.dispersions.append(other)
return self

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
dielectric_function = sum(
Expand Down
9 changes: 4 additions & 5 deletions src/elli/dispersions/cauchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Cauchy dispersion."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import IndexDispersion


class Cauchy(Dispersion):
class Cauchy(IndexDispersion):
r"""Cauchy dispersion.
Single parameters:
Expand All @@ -30,8 +30,8 @@ class Cauchy(Dispersion):
single_params_template = {"n0": 1.5, "n1": 0, "n2": 0, "k0": 0, "k1": 0, "k2": 0}
rep_params_template = {}

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
refr_index = (
def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray:
return (
self.single_params.get("n0")
+ 1e2 * self.single_params.get("n1") / lbda**2
+ 1e7 * self.single_params.get("n2") / lbda**4
Expand All @@ -42,4 +42,3 @@ def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
+ 1e7 * self.single_params.get("k2") / lbda**4
)
)
return refr_index**2
10 changes: 4 additions & 6 deletions src/elli/dispersions/cauchy_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Cauchy dispersion with custom exponents."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import IndexDispersion


class CauchyCustomExponent(Dispersion):
class CauchyCustomExponent(IndexDispersion):
r"""Cauchy dispersion with custom exponents.
Single parameters:
Expand All @@ -24,9 +24,7 @@ class CauchyCustomExponent(Dispersion):
single_params_template = {"n0": 1.5}
rep_params_template = {"f": 0, "e": 1}

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
refr_index = self.single_params.get("n0") + sum(
def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray:
return self.single_params.get("n0") + sum(
c.get("f") * lbda ** c.get("e") for c in self.rep_params
)

return refr_index**2
9 changes: 4 additions & 5 deletions src/elli/dispersions/constant_refractive_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"""Constant refractive index."""
import numpy.typing as npt

from .base_dispersion import Dispersion
from .base_dispersion import IndexDispersion


class ConstantRefractiveIndex(Dispersion):
class ConstantRefractiveIndex(IndexDispersion):
r"""Constant refractive index.
Single parameters:
Expand All @@ -18,9 +18,8 @@ class ConstantRefractiveIndex(Dispersion):
.. math::
\varepsilon(\lambda) = \boldsymbol{n}^2
"""

single_params_template = {"n": 1}
rep_params_template = {}

def dielectric_function(self, _: npt.ArrayLike) -> npt.NDArray:
return self.single_params.get("n") ** 2
def refractive_index(self, _: npt.ArrayLike) -> npt.NDArray:
return self.single_params.get("n")
4 changes: 4 additions & 0 deletions src/elli/dispersions/table_epsilon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Encoding: utf-8
"""Dispersion specified by a table of wavelengths (nm) and dielectric function values."""
from typing import Union
import numpy as np
import numpy.typing as npt
import scipy.interpolate
Expand Down Expand Up @@ -49,5 +50,8 @@ def __init__(self, *args, **kwargs) -> None:
kind="cubic",
)

def __add__(self, _: Union[int, float, "Dispersion"]) -> "DispersionSum":
raise NotImplementedError("Adding of tabular dispersions is not yet supported")

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
return self.interpolation(lbda)
8 changes: 4 additions & 4 deletions src/elli/dispersions/table_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy.typing as npt
import scipy.interpolate

from .base_dispersion import Dispersion, InvalidParameters
from .base_dispersion import IndexDispersion, InvalidParameters


class Table(Dispersion):
class Table(IndexDispersion):
"""Dispersion specified by a table of wavelengths (nm) and refractive index values.
Please not that this model will produce errors for wavelengths outside the provided
wavelength range.
Expand Down Expand Up @@ -40,9 +40,9 @@ def __init__(self, *args, **kwargs) -> None:

self.interpolation = scipy.interpolate.interp1d(
self.single_params.get("lbda"),
self.single_params.get("n") ** 2,
self.single_params.get("n"),
kind="cubic",
)

def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray:
return self.interpolation(lbda)
73 changes: 73 additions & 0 deletions tests/test_dispersion_adding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Test adding of dispersions"""
import pytest
from numpy.testing import assert_array_almost_equal
from elli import Cauchy, Sellmeier
from elli.dispersions.base_dispersion import DispersionSum
from elli.dispersions.table_epsilon import TableEpsilon


def test_fail_on_adding_index_dispersion():
"""Test whether adding for an index based model fails"""
cauchy_err_str = "Adding of index based dispersions is not supported yet"
with pytest.raises(NotImplementedError) as sum_err:
_ = Cauchy() + Cauchy()

assert cauchy_err_str in str(sum_err.value)


def test_fail_on_adding_index_and_diel_dispersion():
"""Test whether the adding fails for an index based and dielectric dispersion"""

for disp in [1, Sellmeier()]:
with pytest.raises(TypeError) as sum_err:
_ = disp + Cauchy()

assert (
"Cannot add refractive index and dielectric function based dispersions."
in str(sum_err.value)
)


def test_adding_of_diel_dispersions():
"""Test if dielectric dispersions are added correctly"""

dispersion_sum = Sellmeier() + Sellmeier()

assert isinstance(dispersion_sum, DispersionSum)
assert len(dispersion_sum.dispersions) == 2

for disp in dispersion_sum.dispersions:
assert isinstance(disp, Sellmeier)

assert_array_almost_equal(
dispersion_sum.get_dielectric_df().values,
2 * Sellmeier().get_dielectric_df().values,
)


def test_flat_dispersion_sum_on_multiple_add():
"""Test whether the DispersionSum stays flat on multiple adds"""

dispersion_sum = Sellmeier() + Sellmeier() + Sellmeier()

assert isinstance(dispersion_sum, DispersionSum)
assert len(dispersion_sum.dispersions) == 3

for disp in dispersion_sum.dispersions:
assert isinstance(disp, Sellmeier)

assert_array_almost_equal(
dispersion_sum.get_dielectric_df().values,
3 * Sellmeier().get_dielectric_df().values,
)


def test_adding_of_tabular_dispersions():
"""Tests correct adding of tabular dispersions"""

with pytest.raises(NotImplementedError) as not_impl_err:
_ = TableEpsilon() + 1

assert (
str(not_impl_err.value) == "Adding of tabular dispersions is not yet supported"
)

0 comments on commit e8a0789

Please sign in to comment.