 import importlib
 import itertools
 import os
-import typing
 import warnings
 from copy import deepcopy
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
 import orjson
 from pymatgen.electronic_structure.core import Spin
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+    from typing import Any, ClassVar, TextIO
+
     from numpy.typing import ArrayLike, NDArray
-    from typing_extensions import Any, Self
+    from typing_extensions import Self
 
     from pymatgen.core.structure import IStructure
+    from pymatgen.util.typing import PathLike
 
 
 class VolumetricData(MSONable):
@@ -62,7 +65,7 @@ class VolumetricData(MSONable):
     def __init__(
         self,
         structure: Structure | IStructure,
-        data: dict[str, np.ndarray],
+        data: dict[str, NDArray],
        distance_matrix: dict | None = None,
        data_aug: dict[str, NDArray] | None = None,
    ) -> None:
@@ -81,15 +84,15 @@ def __init__(
8184 (typically augmentation charges)
8285 """
8386 self .structure = structure
84- self .is_spin_polarized = len (data ) >= 2
85- self .is_soc = len (data ) >= 4
87+ self .is_spin_polarized : bool = len (data ) >= 2
88+ self .is_soc : bool = len (data ) >= 4
8689 # convert data to numpy arrays in case they were jsanitized as lists
87- self .data = {k : np .array (v ) for k , v in data .items ()}
90+ self .data : dict [ str , NDArray ] = {k : np .asarray (v ) for k , v in data .items ()}
8891 self .dim = self .data ["total" ].shape
8992 self .data_aug = data_aug or {}
9093 self .ngridpts = self .dim [0 ] * self .dim [1 ] * self .dim [2 ]
9194 # lazy init the spin data since this is not always needed.
92- self ._spin_data : dict [Spin , float ] = {}
95+ self ._spin_data : dict [Spin , NDArray ] = {}
9396 self ._distance_matrix = distance_matrix if distance_matrix is not None else {}
9497 self .xpoints = np .linspace (0.0 , 1.0 , num = self .dim [0 ])
9598 self .ypoints = np .linspace (0.0 , 1.0 , num = self .dim [1 ])
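
A minimal construction sketch for reference (the `structure` object, the `rho_up`/`rho_down` arrays, and the grid size are placeholders, not taken from this diff):

    import numpy as np

    rho_up = rho_down = np.zeros((48, 48, 48))  # hypothetical spin densities on a 48x48x48 grid
    vd = VolumetricData(
        structure,  # a pymatgen Structure for the same cell
        data={"total": rho_up + rho_down, "diff": rho_up - rho_down},
    )
    assert vd.is_spin_polarized and vd.dim == (48, 48, 48)
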
@@ -101,8 +104,23 @@ def __init__(
         )
         self.name = "VolumetricData"
 
+    def __add__(self, other) -> Self:
+        return self.linear_add(other, 1.0)
+
+    def __radd__(self, other) -> Self:
+        if other == 0 or other is None:
+            # sum() calls 0 + self first; we treat 0 as the identity element
+            return self
+        if isinstance(other, self.__class__):
+            return self.__add__(other)
+
+        raise TypeError(f"Unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'")
+
+    def __sub__(self, other) -> Self:
+        return self.linear_add(other, -1.0)
+
     @property
-    def spin_data(self):
+    def spin_data(self) -> dict[Spin, NDArray]:
         """The data decomposed into actual spin data as {spin: data}.
         Essentially, this provides the actual Spin.up and Spin.down data
         instead of the total and diff. Note that by definition, a
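
A short sketch of how the moved operators compose with builtin `sum()` (file names are illustrative; `Chgcar` is the VolumetricData subclass in `pymatgen.io.vasp`):

    from pymatgen.io.vasp import Chgcar

    parts = [Chgcar.from_file(f) for f in ("AECCAR0", "AECCAR2")]
    total = sum(parts)          # __radd__ treats the starting 0 as the identity, __add__ handles the rest
    diff = parts[0] - parts[1]  # __sub__ delegates to linear_add(other, -1.0)
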
@@ -115,7 +133,7 @@ def spin_data(self):
             self._spin_data = spin_data
         return self._spin_data
 
-    def get_axis_grid(self, ind):
+    def get_axis_grid(self, ind: int) -> list[float]:
         """Get the grid for a particular axis.
 
         Args:
@@ -126,21 +144,6 @@ def get_axis_grid(self, ind):
         lengths = self.structure.lattice.abc
         return [i / num_pts * lengths[ind] for i in range(num_pts)]
 
-    def __add__(self, other):
-        return self.linear_add(other, 1.0)
-
-    def __radd__(self, other):
-        if other == 0 or other is None:
-            # sum() calls 0 + self first; we treat 0 as the identity element
-            return self
-        if isinstance(other, self.__class__):
-            return self.__add__(other)
-
-        raise TypeError(f"Unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'")
-
-    def __sub__(self, other):
-        return self.linear_add(other, -1.0)
-
     def copy(self) -> Self:
         """Make a copy of VolumetricData object."""
         return type(self)(
@@ -150,7 +153,7 @@ def copy(self) -> Self:
             data_aug=self.data_aug,
         )
 
-    def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
+    def linear_add(self, other, scale_factor: float = 1.0) -> Self:
         """
         Method to do a linear sum of volumetric objects. Used by + and -
         operators as well. Returns a VolumetricData object containing the
@@ -181,12 +184,12 @@ def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
             new.data_aug = {}
         return new
 
-    def scale(self, factor):
+    def scale(self, factor: float) -> None:
         """Scale the data in place by a factor."""
         for k in self.data:
             self.data[k] = np.multiply(self.data[k], factor)
 
-    def value_at(self, x, y, z):
+    def value_at(self, x: float, y: float, z: float) -> float:
         """Get a data value from self.data at a given point (x, y, z) in terms
         of fractional lattice parameters. Will be interpolated using a
         RegularGridInterpolator on self.data if (x, y, z) is not in the original
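
With the typed `scale_factor`, a charge-density difference is the explicit form of the `-` operator; a brief sketch with hypothetical variable names:

    # difference between a defect and a bulk calculation on the same FFT grid
    drho = chgcar_defect.linear_add(chgcar_bulk, scale_factor=-1.0)
    # equivalent: drho = chgcar_defect - chgcar_bulk; drho.scale(2.0) would then double it in place
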
@@ -227,7 +230,7 @@ def linear_slice(self, p1: ArrayLike, p2: ArrayLike, n=100):
         z_pts = np.linspace(p1[2], p2[2], num=n)
         return [self.value_at(x_pts[i], y_pts[i], z_pts[i]) for i in range(n)]
 
-    def get_integrated_diff(self, ind, radius, nbins=1):
+    def get_integrated_diff(self, ind: int, radius: float, nbins: int = 1) -> NDArray:
         """Get integrated difference of atom index ind up to radius. This can be
         an extremely computationally intensive process, depending on how many
         grid points are in the VolumetricData.
@@ -273,13 +276,13 @@ def get_integrated_diff(self, ind, radius, nbins=1):
         data_inds = np.rint(np.mod(list(data[inds, 0]), 1) * np.tile(a, (len(dists), 1))).astype(int)
         vals = [self.data["diff"][x, y, z] for x, y, z in data_inds]
 
-        hist, edges = np.histogram(dists, bins=nbins, range=[0, radius], weights=vals)
+        hist, edges = np.histogram(dists, bins=nbins, range=(0, radius), weights=vals)
         data = np.zeros((nbins, 2))
         data[:, 0] = edges[1:]
         data[:, 1] = [sum(hist[0 : i + 1]) / self.ngridpts for i in range(nbins)]
         return data
 
-    def get_average_along_axis(self, ind):
+    def get_average_along_axis(self, ind: int) -> NDArray:
         """Get the averaged total of the volumetric data a certain axis direction.
         For example, useful for visualizing Hartree Potentials from a LOCPOT
         file.
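
For orientation, a sketch of the now-annotated `NDArray` return value (the `chgcar` object and numbers are illustrative):

    prof = chgcar.get_integrated_diff(ind=0, radius=1.0, nbins=10)
    # prof[:, 0] holds the upper bin edges (radii), prof[:, 1] the cumulative integrated
    # "diff" (spin) density around site 0 out to each radius
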
@@ -300,7 +303,7 @@ def get_average_along_axis(self, ind):
         total = np.sum(np.sum(total_spin_dens, axis=0), 0)
         return total / ng[(ind + 1) % 3] / ng[(ind + 2) % 3]
 
-    def to_hdf5(self, filename):
+    def to_hdf5(self, filename: PathLike) -> None:
         """Write the VolumetricData to a HDF5 format, which is a highly optimized
         format for reading storing large data. The mapping of the VolumetricData
         to this file format is as follows:
@@ -318,7 +321,7 @@ def to_hdf5(self, filename):
         """
         import h5py
 
-        with h5py.File(filename, mode="w") as file:
+        with h5py.File(str(filename), mode="w") as file:
             ds = file.create_dataset("lattice", (3, 3), dtype="float")
             ds[...] = self.structure.lattice.matrix
             ds = file.create_dataset("Z", (len(self.structure.species),), dtype="i")
@@ -336,7 +339,7 @@ def to_hdf5(self, filename):
             file.attrs["structure_json"] = orjson.dumps(self.structure.as_dict()).decode()
 
     @classmethod
-    def from_hdf5(cls, filename: str, **kwargs) -> VolumetricData:
+    def from_hdf5(cls, filename: PathLike, **kwargs) -> Self:
         """
         Reads VolumetricData from HDF5 file.
 
@@ -348,15 +351,15 @@ def from_hdf5(cls, filename: str, **kwargs) -> VolumetricData:
         """
         import h5py
 
-        with h5py.File(filename, mode="r") as file:
-            data = {k: np.array(v) for k, v in file["vdata"].items()}
+        with h5py.File(str(filename), mode="r") as file:
+            data = {k: np.asarray(v) for k, v in file["vdata"].items()}
             data_aug = None
             if "vdata_aug" in file:
-                data_aug = {k: np.array(v) for k, v in file["vdata_aug"].items()}
+                data_aug = {k: np.asarray(v) for k, v in file["vdata_aug"].items()}
             structure = Structure.from_dict(orjson.loads(file.attrs["structure_json"]))
             return cls(structure, data=data, data_aug=data_aug, **kwargs)  # type:ignore[arg-type]
 
-    def to_cube(self, filename, comment: str = ""):
+    def to_cube(self, filename: PathLike, comment: str = "") -> None:
         """Write the total volumetric data to a cube file format, which consists of two comment lines,
         a header section defining the structure IN BOHR, and the data.
 
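
An illustrative HDF5 round-trip under the new PathLike-typed signatures (the file name is arbitrary):

    from pathlib import Path

    chgcar.to_hdf5(Path("CHGCAR.h5"))         # Path objects now type-check; str(filename) is handed to h5py
    restored = Chgcar.from_hdf5("CHGCAR.h5")  # the classmethod now returns Self, i.e. the calling subclass
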
@@ -365,31 +368,32 @@ def to_cube(self, filename, comment: str = ""):
             comment (str): If provided, this will be added to the second comment line
         """
         with zopen(filename, mode="wt", encoding="utf-8") as file:
-            file.write(f"# Cube file for {self.structure.formula} generated by Pymatgen\n")  # type:ignore[arg-type]
-            file.write(f"# {comment}\n")  # type:ignore[arg-type]  # type:ignore[arg-type]
-            file.write(f"\t {len(self.structure)} 0.000000 0.000000 0.000000\n")  # type:ignore[arg-type]
+            file = cast("TextIO", file)
+            file.write(f"# Cube file for {self.structure.formula} generated by Pymatgen\n")
+            file.write(f"# {comment}\n")
+            file.write(f"\t {len(self.structure)} 0.000000 0.000000 0.000000\n")
 
             for idx in range(3):
                 lattice_matrix = self.structure.lattice.matrix[idx] / self.dim[idx] * ang_to_bohr
                 file.write(
-                    f"\t {self.dim[idx]} {lattice_matrix[0]:.6f} {lattice_matrix[1]:.6f} {lattice_matrix[2]:.6f}\n"  # type:ignore[arg-type]
+                    f"\t {self.dim[idx]} {lattice_matrix[0]:.6f} {lattice_matrix[1]:.6f} {lattice_matrix[2]:.6f}\n"
                 )
 
             for site in self.structure:
                 file.write(
-                    f"\t {Element(site.species_string).Z} 0.000000 "  # type:ignore[arg-type]
-                    f"{ang_to_bohr * site.coords[0]} "  # type:ignore[arg-type]
-                    f"{ang_to_bohr * site.coords[1]} "  # type:ignore[arg-type]
-                    f"{ang_to_bohr * site.coords[2]}\n"  # type:ignore[arg-type]
+                    f"\t {Element(site.species_string).Z} 0.000000 "
+                    f"{ang_to_bohr * site.coords[0]} "
+                    f"{ang_to_bohr * site.coords[1]} "
+                    f"{ang_to_bohr * site.coords[2]}\n"
                 )
 
             for idx, dat in enumerate(self.data["total"].flatten(), start=1):
-                file.write(f"{' ' if dat > 0 else ''}{dat:.6e} ")  # type:ignore[arg-type]
+                file.write(f"{' ' if dat > 0 else ''}{dat:.6e} ")
                 if idx % 6 == 0:
-                    file.write("\n")  # type:ignore[arg-type]
+                    file.write("\n")
 
     @classmethod
-    def from_cube(cls, filename: str | Path) -> Self:
+    def from_cube(cls, filename: PathLike) -> Self:
         """
         Initialize the cube object and store the data as pymatgen objects.
 
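
The `cast("TextIO", file)` only narrows zopen's return type for the type checker and replaces the per-line type:ignore comments; a cube round-trip then looks like this (file name illustrative):

    vd.to_cube("rho.cube", comment="total charge density")
    vd2 = VolumetricData.from_cube("rho.cube")  # accepts str or Path via PathLike
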
@@ -459,9 +463,9 @@ class PMGDir(collections.abc.Mapping):
     ```
     """
 
-    FILE_MAPPINGS: typing.ClassVar = {
-        n: f"pymatgen.io.vasp.{n.capitalize()}"
-        for n in [
+    FILE_MAPPINGS: ClassVar[dict[str, str]] = {
+        name: f"pymatgen.io.vasp.{name.capitalize()}"
+        for name in (
             "INCAR",
             "POSCAR",
             "KPOINTS",
@@ -478,41 +482,31 @@ class PMGDir(collections.abc.Mapping):
             "PROCAR",
             "ELFCAR",
             "DYNMAT",
-        ]
+        )
     } | {
         "CONTCAR": "pymatgen.io.vasp.Poscar",
         "IBZKPT": "pymatgen.io.vasp.Kpoints",
         "WSWQ": "pymatgen.io.vasp.WSWQ",
     }
 
-    def __init__(self, dirname: str | Path):
+    def __init__(self, dirname: PathLike) -> None:
         """
         Args:
             dirname: The directory containing the VASP calculation as a string or Path.
         """
         self.path = Path(dirname).absolute()
         self.reset()
 
-    def reset(self):
-        """
-        Reset all loaded files and recheck the directory for files. Use this when the contents of the directory has
-        changed.
-        """
-        # Note that py3.12 has Path.walk(). But we need to use os.walk to ensure backwards compatibility for now.
-        self._files: dict[str, Any] = {
-            str((Path(d) / f).relative_to(self.path)): None for d, _, fnames in os.walk(self.path) for f in fnames
-        }
-
-    def __contains__(self, item):
+    def __contains__(self, item) -> bool:
         return item in self._files
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._files)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._files)
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: str) -> Any:
         if self._files.get(item):
             return self._files.get(item)
         fpath = self.path / item
@@ -539,6 +533,19 @@ def __getitem__(self, item):
         with zopen(fpath, mode="rt", encoding="utf-8") as f:
             return f.read()
 
+    def __repr__(self) -> str:
+        return f"PMGDir({self.path})"
+
+    def reset(self) -> None:
+        """
+        Reset all loaded files and recheck the directory for files. Use this when the contents of the directory has
+        changed.
+        """
+        # Note that py3.12 has Path.walk(). But we need to use os.walk to ensure backwards compatibility for now.
+        self._files: dict[str, Any] = {
+            str((Path(d) / f).relative_to(self.path)): None for d, _, fnames in os.walk(self.path) for f in fnames
+        }
+
     def get_files_by_name(self, name: str) -> dict[str, Any]:
         """
         Returns all files with a given name. E.g., if you want all the OUTCAR files, set name="OUTCAR".
@@ -547,6 +554,3 @@ def get_files_by_name(self, name: str) -> dict[str, Any]:
             {filename: object from PMGDir[filename]}
         """
         return {f: self[f] for f in self._files if name in f}
-
-    def __repr__(self):
-        return f"PMGDir({self.path})"
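
A small sketch of the Mapping interface after the method reordering (directory path and contents hypothetical):

    calc = PMGDir("/path/to/vasp_calc")
    outcars = calc.get_files_by_name("OUTCAR")  # e.g. {"OUTCAR": <Outcar>, "relax2/OUTCAR": <Outcar>}
    print(repr(calc))                           # -> PMGDir(/path/to/vasp_calc)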