Skip to content

Commit b1ea39b

Browse files
authored
Minor type cleanup of io.common (#4525)
* use `from typing import XXX` over `import typing` * add types for io.common * use asarray when it could already be array
1 parent 4103783 commit b1ea39b

File tree

5 files changed

+84
-79
lines changed

5 files changed

+84
-79
lines changed

src/pymatgen/io/atat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from __future__ import annotations
44

5-
import typing
5+
from typing import TYPE_CHECKING
66

77
import numpy as np
88

99
from pymatgen.core import Lattice, Structure, get_el_sp
1010

11-
if typing.TYPE_CHECKING:
11+
if TYPE_CHECKING:
1212
from pymatgen.core.structure import IStructure
1313

1414
__author__ = "Matthew Horton"

src/pymatgen/io/common.py

Lines changed: 75 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
import importlib
77
import itertools
88
import os
9-
import typing
109
import warnings
1110
from copy import deepcopy
1211
from pathlib import Path
13-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, cast
1413

1514
import numpy as np
1615
import orjson
@@ -23,10 +22,14 @@
2322
from pymatgen.electronic_structure.core import Spin
2423

2524
if TYPE_CHECKING:
25+
from collections.abc import Iterator
26+
from typing import Any, ClassVar, TextIO
27+
2628
from numpy.typing import ArrayLike, NDArray
27-
from typing_extensions import Any, Self
29+
from typing_extensions import Self
2830

2931
from pymatgen.core.structure import IStructure
32+
from pymatgen.util.typing import PathLike
3033

3134

3235
class VolumetricData(MSONable):
@@ -62,7 +65,7 @@ class VolumetricData(MSONable):
6265
def __init__(
6366
self,
6467
structure: Structure | IStructure,
65-
data: dict[str, np.ndarray],
68+
data: dict[str, NDArray],
6669
distance_matrix: dict | None = None,
6770
data_aug: dict[str, NDArray] | None = None,
6871
) -> None:
@@ -81,15 +84,15 @@ def __init__(
8184
(typically augmentation charges)
8285
"""
8386
self.structure = structure
84-
self.is_spin_polarized = len(data) >= 2
85-
self.is_soc = len(data) >= 4
87+
self.is_spin_polarized: bool = len(data) >= 2
88+
self.is_soc: bool = len(data) >= 4
8689
# convert data to numpy arrays in case they were jsanitized as lists
87-
self.data = {k: np.array(v) for k, v in data.items()}
90+
self.data: dict[str, NDArray] = {k: np.asarray(v) for k, v in data.items()}
8891
self.dim = self.data["total"].shape
8992
self.data_aug = data_aug or {}
9093
self.ngridpts = self.dim[0] * self.dim[1] * self.dim[2]
9194
# lazy init the spin data since this is not always needed.
92-
self._spin_data: dict[Spin, float] = {}
95+
self._spin_data: dict[Spin, NDArray] = {}
9396
self._distance_matrix = distance_matrix if distance_matrix is not None else {}
9497
self.xpoints = np.linspace(0.0, 1.0, num=self.dim[0])
9598
self.ypoints = np.linspace(0.0, 1.0, num=self.dim[1])
@@ -101,8 +104,23 @@ def __init__(
101104
)
102105
self.name = "VolumetricData"
103106

107+
def __add__(self, other) -> Self:
108+
return self.linear_add(other, 1.0)
109+
110+
def __radd__(self, other) -> Self:
111+
if other == 0 or other is None:
112+
# sum() calls 0 + self first; we treat 0 as the identity element
113+
return self
114+
if isinstance(other, self.__class__):
115+
return self.__add__(other)
116+
117+
raise TypeError(f"Unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'")
118+
119+
def __sub__(self, other) -> Self:
120+
return self.linear_add(other, -1.0)
121+
104122
@property
105-
def spin_data(self):
123+
def spin_data(self) -> dict[Spin, NDArray]:
106124
"""The data decomposed into actual spin data as {spin: data}.
107125
Essentially, this provides the actual Spin.up and Spin.down data
108126
instead of the total and diff. Note that by definition, a
@@ -115,7 +133,7 @@ def spin_data(self):
115133
self._spin_data = spin_data
116134
return self._spin_data
117135

118-
def get_axis_grid(self, ind):
136+
def get_axis_grid(self, ind: int) -> list[float]:
119137
"""Get the grid for a particular axis.
120138
121139
Args:
@@ -126,21 +144,6 @@ def get_axis_grid(self, ind):
126144
lengths = self.structure.lattice.abc
127145
return [i / num_pts * lengths[ind] for i in range(num_pts)]
128146

129-
def __add__(self, other):
130-
return self.linear_add(other, 1.0)
131-
132-
def __radd__(self, other):
133-
if other == 0 or other is None:
134-
# sum() calls 0 + self first; we treat 0 as the identity element
135-
return self
136-
if isinstance(other, self.__class__):
137-
return self.__add__(other)
138-
139-
raise TypeError(f"Unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'")
140-
141-
def __sub__(self, other):
142-
return self.linear_add(other, -1.0)
143-
144147
def copy(self) -> Self:
145148
"""Make a copy of VolumetricData object."""
146149
return type(self)(
@@ -150,7 +153,7 @@ def copy(self) -> Self:
150153
data_aug=self.data_aug,
151154
)
152155

153-
def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
156+
def linear_add(self, other, scale_factor: float = 1.0) -> Self:
154157
"""
155158
Method to do a linear sum of volumetric objects. Used by + and -
156159
operators as well. Returns a VolumetricData object containing the
@@ -181,12 +184,12 @@ def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
181184
new.data_aug = {}
182185
return new
183186

184-
def scale(self, factor):
187+
def scale(self, factor: float) -> None:
185188
"""Scale the data in place by a factor."""
186189
for k in self.data:
187190
self.data[k] = np.multiply(self.data[k], factor)
188191

189-
def value_at(self, x, y, z):
192+
def value_at(self, x: float, y: float, z: float) -> float:
190193
"""Get a data value from self.data at a given point (x, y, z) in terms
191194
of fractional lattice parameters. Will be interpolated using a
192195
RegularGridInterpolator on self.data if (x, y, z) is not in the original
@@ -227,7 +230,7 @@ def linear_slice(self, p1: ArrayLike, p2: ArrayLike, n=100):
227230
z_pts = np.linspace(p1[2], p2[2], num=n)
228231
return [self.value_at(x_pts[i], y_pts[i], z_pts[i]) for i in range(n)]
229232

230-
def get_integrated_diff(self, ind, radius, nbins=1):
233+
def get_integrated_diff(self, ind: int, radius: float, nbins: int = 1) -> NDArray:
231234
"""Get integrated difference of atom index ind up to radius. This can be
232235
an extremely computationally intensive process, depending on how many
233236
grid points are in the VolumetricData.
@@ -273,13 +276,13 @@ def get_integrated_diff(self, ind, radius, nbins=1):
273276
data_inds = np.rint(np.mod(list(data[inds, 0]), 1) * np.tile(a, (len(dists), 1))).astype(int)
274277
vals = [self.data["diff"][x, y, z] for x, y, z in data_inds]
275278

276-
hist, edges = np.histogram(dists, bins=nbins, range=[0, radius], weights=vals)
279+
hist, edges = np.histogram(dists, bins=nbins, range=(0, radius), weights=vals)
277280
data = np.zeros((nbins, 2))
278281
data[:, 0] = edges[1:]
279282
data[:, 1] = [sum(hist[0 : i + 1]) / self.ngridpts for i in range(nbins)]
280283
return data
281284

282-
def get_average_along_axis(self, ind):
285+
def get_average_along_axis(self, ind: int) -> NDArray:
283286
"""Get the averaged total of the volumetric data a certain axis direction.
284287
For example, useful for visualizing Hartree Potentials from a LOCPOT
285288
file.
@@ -300,7 +303,7 @@ def get_average_along_axis(self, ind):
300303
total = np.sum(np.sum(total_spin_dens, axis=0), 0)
301304
return total / ng[(ind + 1) % 3] / ng[(ind + 2) % 3]
302305

303-
def to_hdf5(self, filename):
306+
def to_hdf5(self, filename: PathLike) -> None:
304307
"""Write the VolumetricData to a HDF5 format, which is a highly optimized
305308
format for reading storing large data. The mapping of the VolumetricData
306309
to this file format is as follows:
@@ -318,7 +321,7 @@ def to_hdf5(self, filename):
318321
"""
319322
import h5py
320323

321-
with h5py.File(filename, mode="w") as file:
324+
with h5py.File(str(filename), mode="w") as file:
322325
ds = file.create_dataset("lattice", (3, 3), dtype="float")
323326
ds[...] = self.structure.lattice.matrix
324327
ds = file.create_dataset("Z", (len(self.structure.species),), dtype="i")
@@ -336,7 +339,7 @@ def to_hdf5(self, filename):
336339
file.attrs["structure_json"] = orjson.dumps(self.structure.as_dict()).decode()
337340

338341
@classmethod
339-
def from_hdf5(cls, filename: str, **kwargs) -> VolumetricData:
342+
def from_hdf5(cls, filename: PathLike, **kwargs) -> Self:
340343
"""
341344
Reads VolumetricData from HDF5 file.
342345
@@ -348,15 +351,15 @@ def from_hdf5(cls, filename: str, **kwargs) -> VolumetricData:
348351
"""
349352
import h5py
350353

351-
with h5py.File(filename, mode="r") as file:
352-
data = {k: np.array(v) for k, v in file["vdata"].items()}
354+
with h5py.File(str(filename), mode="r") as file:
355+
data = {k: np.asarray(v) for k, v in file["vdata"].items()}
353356
data_aug = None
354357
if "vdata_aug" in file:
355-
data_aug = {k: np.array(v) for k, v in file["vdata_aug"].items()}
358+
data_aug = {k: np.asarray(v) for k, v in file["vdata_aug"].items()}
356359
structure = Structure.from_dict(orjson.loads(file.attrs["structure_json"]))
357360
return cls(structure, data=data, data_aug=data_aug, **kwargs) # type:ignore[arg-type]
358361

359-
def to_cube(self, filename, comment: str = ""):
362+
def to_cube(self, filename: PathLike, comment: str = "") -> None:
360363
"""Write the total volumetric data to a cube file format, which consists of two comment lines,
361364
a header section defining the structure IN BOHR, and the data.
362365
@@ -365,31 +368,32 @@ def to_cube(self, filename, comment: str = ""):
365368
comment (str): If provided, this will be added to the second comment line
366369
"""
367370
with zopen(filename, mode="wt", encoding="utf-8") as file:
368-
file.write(f"# Cube file for {self.structure.formula} generated by Pymatgen\n") # type:ignore[arg-type]
369-
file.write(f"# {comment}\n") # type:ignore[arg-type] # type:ignore[arg-type]
370-
file.write(f"\t {len(self.structure)} 0.000000 0.000000 0.000000\n") # type:ignore[arg-type]
371+
file = cast("TextIO", file)
372+
file.write(f"# Cube file for {self.structure.formula} generated by Pymatgen\n")
373+
file.write(f"# {comment}\n")
374+
file.write(f"\t {len(self.structure)} 0.000000 0.000000 0.000000\n")
371375

372376
for idx in range(3):
373377
lattice_matrix = self.structure.lattice.matrix[idx] / self.dim[idx] * ang_to_bohr
374378
file.write(
375-
f"\t {self.dim[idx]} {lattice_matrix[0]:.6f} {lattice_matrix[1]:.6f} {lattice_matrix[2]:.6f}\n" # type:ignore[arg-type]
379+
f"\t {self.dim[idx]} {lattice_matrix[0]:.6f} {lattice_matrix[1]:.6f} {lattice_matrix[2]:.6f}\n"
376380
)
377381

378382
for site in self.structure:
379383
file.write(
380-
f"\t {Element(site.species_string).Z} 0.000000 " # type:ignore[arg-type]
381-
f"{ang_to_bohr * site.coords[0]} " # type:ignore[arg-type]
382-
f"{ang_to_bohr * site.coords[1]} " # type:ignore[arg-type]
383-
f"{ang_to_bohr * site.coords[2]} \n" # type:ignore[arg-type]
384+
f"\t {Element(site.species_string).Z} 0.000000 "
385+
f"{ang_to_bohr * site.coords[0]} "
386+
f"{ang_to_bohr * site.coords[1]} "
387+
f"{ang_to_bohr * site.coords[2]} \n"
384388
)
385389

386390
for idx, dat in enumerate(self.data["total"].flatten(), start=1):
387-
file.write(f"{' ' if dat > 0 else ''}{dat:.6e} ") # type:ignore[arg-type]
391+
file.write(f"{' ' if dat > 0 else ''}{dat:.6e} ")
388392
if idx % 6 == 0:
389-
file.write("\n") # type:ignore[arg-type]
393+
file.write("\n")
390394

391395
@classmethod
392-
def from_cube(cls, filename: str | Path) -> Self:
396+
def from_cube(cls, filename: PathLike) -> Self:
393397
"""
394398
Initialize the cube object and store the data as pymatgen objects.
395399
@@ -459,9 +463,9 @@ class PMGDir(collections.abc.Mapping):
459463
```
460464
"""
461465

462-
FILE_MAPPINGS: typing.ClassVar = {
463-
n: f"pymatgen.io.vasp.{n.capitalize()}"
464-
for n in [
466+
FILE_MAPPINGS: ClassVar[dict[str, str]] = {
467+
name: f"pymatgen.io.vasp.{name.capitalize()}"
468+
for name in (
465469
"INCAR",
466470
"POSCAR",
467471
"KPOINTS",
@@ -478,41 +482,31 @@ class PMGDir(collections.abc.Mapping):
478482
"PROCAR",
479483
"ELFCAR",
480484
"DYNMAT",
481-
]
485+
)
482486
} | {
483487
"CONTCAR": "pymatgen.io.vasp.Poscar",
484488
"IBZKPT": "pymatgen.io.vasp.Kpoints",
485489
"WSWQ": "pymatgen.io.vasp.WSWQ",
486490
}
487491

488-
def __init__(self, dirname: str | Path):
492+
def __init__(self, dirname: PathLike) -> None:
489493
"""
490494
Args:
491495
dirname: The directory containing the VASP calculation as a string or Path.
492496
"""
493497
self.path = Path(dirname).absolute()
494498
self.reset()
495499

496-
def reset(self):
497-
"""
498-
Reset all loaded files and recheck the directory for files. Use this when the contents of the directory has
499-
changed.
500-
"""
501-
# Note that py3.12 has Path.walk(). But we need to use os.walk to ensure backwards compatibility for now.
502-
self._files: dict[str, Any] = {
503-
str((Path(d) / f).relative_to(self.path)): None for d, _, fnames in os.walk(self.path) for f in fnames
504-
}
505-
506-
def __contains__(self, item):
500+
def __contains__(self, item) -> bool:
507501
return item in self._files
508502

509-
def __len__(self):
503+
def __len__(self) -> int:
510504
return len(self._files)
511505

512-
def __iter__(self):
506+
def __iter__(self) -> Iterator[str]:
513507
return iter(self._files)
514508

515-
def __getitem__(self, item):
509+
def __getitem__(self, item: str) -> Any:
516510
if self._files.get(item):
517511
return self._files.get(item)
518512
fpath = self.path / item
@@ -539,6 +533,19 @@ def __getitem__(self, item):
539533
with zopen(fpath, mode="rt", encoding="utf-8") as f:
540534
return f.read()
541535

536+
def __repr__(self) -> str:
537+
return f"PMGDir({self.path})"
538+
539+
def reset(self) -> None:
540+
"""
541+
Reset all loaded files and recheck the directory for files. Use this when the contents of the directory has
542+
changed.
543+
"""
544+
# Note that py3.12 has Path.walk(). But we need to use os.walk to ensure backwards compatibility for now.
545+
self._files: dict[str, Any] = {
546+
str((Path(d) / f).relative_to(self.path)): None for d, _, fnames in os.walk(self.path) for f in fnames
547+
}
548+
542549
def get_files_by_name(self, name: str) -> dict[str, Any]:
543550
"""
544551
Returns all files with a given name. E.g., if you want all the OUTCAR files, set name="OUTCAR".
@@ -547,6 +554,3 @@ def get_files_by_name(self, name: str) -> dict[str, Any]:
547554
{filename: object from PMGDir[filename]}
548555
"""
549556
return {f: self[f] for f in self._files if name in f}
550-
551-
def __repr__(self):
552-
return f"PMGDir({self.path})"

src/pymatgen/io/cp2k/sets.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import itertools
2323
import os
24-
import typing
2524
import warnings
2625
from typing import TYPE_CHECKING, Any
2726

@@ -367,7 +366,6 @@ def __init__(
367366
if kwargs.get("validate", True):
368367
self.validate()
369368

370-
@typing.no_type_check
371369
@staticmethod
372370
def get_basis_and_potential(
373371
structure: Structure | IStructure,

src/pymatgen/io/phonopy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
import typing
5+
from typing import TYPE_CHECKING
66

77
import numpy as np
88
from monty.dev import requires
@@ -16,7 +16,7 @@
1616
from pymatgen.phonon.thermal_displacements import ThermalDisplacementMatrices
1717
from pymatgen.symmetry.bandstructure import HighSymmKpath
1818

19-
if typing.TYPE_CHECKING:
19+
if TYPE_CHECKING:
2020
from pymatgen.core.structure import IStructure
2121

2222
try:

0 commit comments

Comments
 (0)