237 changes: 149 additions & 88 deletions dptb/data/dataset/_default_dataset.py
@@ -167,66 +167,105 @@ def from_text_data(cls,
@classmethod
def from_ase_traj(cls,
root: str,
get_Hamiltonian = False,
get_overlap = False,
get_DM = False,
get_eigenvalues = False,
info = None):

assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData."

get_Hamiltonian: bool = False,
get_overlap: bool = False,
get_DM: bool = False,
get_eigenvalues: bool = False,
info: Optional[Dict] = None):
'''
Build a _TrajData instance by reading the data from a single data directory
organized in a way compatible with ASE.

Parameters
----------
root: str
The folder where the data is stored, including the .traj file.
get_Hamiltonian: bool
Whether to load the Hamiltonian blocks.
get_overlap: bool
Whether to load the overlap blocks.
get_DM: bool
Whether to load the density matrix blocks.
get_eigenvalues: bool
Whether to load the eigenvalues.
info: dict
The description of the data. It may be inconsistent with the actual data;
in that case, the info will be updated.
'''
assert not get_Hamiltonian * get_DM, \
"Hamiltonian and density matrix can only be loaded one at a time, " + \
"since they occupy the same attribute in the AtomicData."

# read the ase trajectory file...
traj_file = glob.glob(f"{root}/*.traj")
assert len(traj_file) == 1, print("only one ase trajectory file can be provided.")
traj = Trajectory(traj_file[0], 'r')
nframes = len(traj)
assert nframes > 0, print("trajectory file is empty.")
if nframes != info.get("nframes", None):
info['nframes'] = nframes
log.info(f"Number of frames ({nframes}) in trajectory file does not match the number of frames in info file.")

natoms = traj[0].positions.shape[0]
if natoms != info["natoms"]:
info["natoms"] = natoms

pbc = info.get("pbc",None)
if pbc is None:
pbc = traj[0].pbc.tolist()
info["pbc"] = pbc

if isinstance(pbc, bool):
pbc = [pbc] * 3

if pbc != traj[0].pbc.tolist():
log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, we use the one in info json. BE CAREFUL!")

positions = []
cell = []
atomic_numbers = []

for atoms in traj:
positions.append(atoms.get_positions())
assert len(traj_file) == 1, "only one ase trajectory file can be provided."

atomic_numbers, positions, cell = None, None, None
# use the context manager to avoid leaking the file handle
with Trajectory(traj_file[0], 'r') as traj:
# there are some dimensions of importance: nframe, natom, ...
nframes = len(traj)
assert nframes > 0, "trajectory file is empty."
# if there is a discrepancy between nframes in info and the trajectory, update the info.
if nframes != info.get("nframes", None):
info['nframes'] = nframes
log.info(f"Number of frames ({nframes}) in trajectory file does not match the info file.")

atomic_numbers.append(atoms.get_atomic_numbers())
if (np.abs(atoms.get_cell()-np.zeros([3,3]))< 1e-6).all():
cell = None
else:
cell.append(atoms.get_cell())

positions = np.array(positions)
positions = positions.reshape(nframes,natoms, 3)

if cell is not None:
cell = np.array(cell)
cell = cell.reshape(nframes,3, 3)

atomic_numbers = np.array(atomic_numbers)
atomic_numbers = atomic_numbers.reshape(nframes, natoms)

data = {}
if cell is not None:
data["cell"] = cell
data["pos"] = positions
data["atomic_numbers"] = atomic_numbers
# we assume the number of atoms does not change within the trajectory file...
# we check this explicitly, because the trajectory format does allow it.
natoms = np.unique([len(atoms) for atoms in traj])
assert len(natoms) == 1, "Number of atoms in trajectory file is not consistent."
natoms = natoms[0]
# natoms = traj[0].positions.shape[0]
if natoms != info["natoms"]:
info["natoms"] = natoms
log.info(f"Number of atoms ({natoms}) in trajectory file does not match the info file.")

# handling the pbc flag
pbc = info.get("pbc", None)
if pbc is None:
# read from the trajectory; the pbc may also change along the trajectory,
# so we check it and only allow a single, consistent pbc setting
pbc = np.unique([atoms.pbc.tolist() for atoms in traj], axis=0)
assert len(pbc) == 1, "PBC setting in trajectory file is not consistent."
pbc = pbc[0].tolist()
assert isinstance(pbc, list) and len(pbc) == 3, \
f"Unexpected `PBC` format: {pbc}"
info["pbc"] = pbc
# check on the value of pbc
if isinstance(pbc, bool):
pbc = [pbc] * 3
if pbc != traj[0].pbc.tolist():
log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, "
"we use the one in info json. BE CAREFUL!")

# overwrite the following three with empty lists
atomic_numbers, positions, cell = [], [], []
for atoms in traj:
atoms: Atoms # type annotation :)

atomic_numbers.append(atoms.get_atomic_numbers())
positions.append(atoms.get_positions())
# if there is no cell information, set the entry to None. Note that an
# invalid cell may also be represented by an all-zero matrix, which we
# treat the same way.
cell_read = atoms.get_cell()
cell.append(None if np.allclose(cell_read, np.zeros((3, 3)), atol=1e-6) else cell_read)

# the trajectory reading must be successful and not empty
assert positions
assert atomic_numbers
assert cell

# this may raise errors if the data are inhomogeneous or the reshape fails
data = {"pos": np.array(positions).reshape(nframes, natoms, 3),
"atomic_numbers": np.array(atomic_numbers).reshape(nframes, natoms)}
assert len(cell) == nframes
if all(c is not None for c in cell):
data["cell"] = np.array(cell).reshape(nframes, 3, 3)
else:
# otherwise, we expect all cells to be None; a mixture of the two cases is not allowed
assert all(c is None for c in cell)

return cls(root=root,
data=data,
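For reference, a minimal usage sketch of the new classmethod; the import path follows the file shown in this diff, while the directory name and the `info` values are placeholders, not taken from the PR:

from dptb.data.dataset._default_dataset import _TrajData

# hypothetical info dict; the directory is expected to contain exactly one *.traj file
info = {
    "nframes": 10,               # updated from the trajectory if inconsistent
    "natoms": 8,                 # likewise checked against the trajectory
    "pbc": [True, True, True],   # optional; read from the trajectory if omitted
    "pos_type": "ase",
    "r_max": 5.0,
}
subdata = _TrajData.from_ase_traj(root="data/Si_traj", get_Hamiltonian=True, info=info)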
@@ -318,18 +357,47 @@ def toAtomicDataList(self, idp: TypeMapper = None):

class DefaultDataset(AtomicInMemoryDataset):

def __init__(
self,
root: str,
info_files: Dict[str, Dict],
url: Optional[str] = None, # seems useless but can't be remove
include_frames: Optional[List[int]] = None, # maybe support in future
type_mapper: TypeMapper = None,
get_Hamiltonian: bool = False,
get_overlap: bool = False,
get_DM: bool = False,
get_eigenvalues: bool = False,
):
def __init__(self,
root: str,
info_files: Dict[str, Dict],
url: Optional[str] = None,  # seems useless but cannot be removed
include_frames: Optional[List[int]] = None,  # may be supported in the future
type_mapper: TypeMapper = None,
get_Hamiltonian: bool = False,
get_overlap: bool = False,
get_DM: bool = False,
get_eigenvalues: bool = False):
'''
Instantiate the default dataset.

Parameters
----------
root : str
root directory of the dataset.
info_files : Dict[str, Dict]
the description of all the "valid" subfolders in the root directory; "valid"
means the subfolder contains data files.
url : Optional[str], optional
not used in DeePTB; see the superclass.
include_frames : Optional[List[int]], optional
not used in DeePTB; see the superclass.
type_mapper: TypeMapper, optional
the mapping from orbpair index to reduced matrix element; see the docstring of
the OrbitalMapper class for more information.
get_Hamiltonian : bool, optional
whether to get the Hamiltonian, by default False
get_overlap : bool, optional
whether to get the overlap, by default False
get_DM : bool, optional
whether to get the density matrix, by default False
get_eigenvalues : bool, optional
whether to get the eigenvalues, by default False
'''
def build_data(pos_typ: str, **kwargs):
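# dispatch on the position type: 'ase' uses the trajectory reader, any other
# value falls back to the plain-text reader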
builder = {'ase': _TrajData.from_ase_traj}
build_func = builder.get(pos_typ, _TrajData.from_text_data)
return build_func(**kwargs)

self.root = root
self.url = url
self.info_files = info_files
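A minimal instantiation sketch; the paths and the per-folder info values are illustrative assumptions, not taken from this PR (a TypeMapper/OrbitalMapper would normally be passed as type_mapper as well):

# hypothetical example: one subfolder "set.0" under ./data containing an ASE trajectory
info_files = {
    "set.0": {"r_max": 5.0, "pbc": [True, True, True], "pos_type": "ase",
              "natoms": 8, "nframes": 10},
}
dataset = DefaultDataset(root="./data", info_files=info_files, get_Hamiltonian=True)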
@@ -345,24 +413,16 @@ def __init__(
# get the info here
info = info_files[file]
# assert "AtomicData_options" in info
assert "r_max" in info
assert "pbc" in info
pbc = info["pbc"]
if info["pos_type"] == "ase":
subdata = _TrajData.from_ase_traj(os.path.join(self.root, file),
get_Hamiltonian,
get_overlap,
get_DM,
get_eigenvalues,
info=info)
else:
subdata = _TrajData.from_text_data(os.path.join(self.root, file),
get_Hamiltonian,
get_overlap,
get_DM,
get_eigenvalues,
info=info)
self.raw_data.append(subdata)
assert all(attr in info for attr in ["r_max", "pbc"])
pbc = info["pbc"] # not used?
⚠️ Potential issue | 🟡 Minor

Remove unused variable assignment.

Static analysis correctly identifies that pbc is assigned but never used. The comment acknowledges this uncertainty.

Apply this diff:

         assert all(attr in info for attr in ["r_max", "pbc"])
-        pbc = info["pbc"] # not used?
         self.raw_data.append(
🧰 Tools
🪛 Ruff (0.14.7)

417-417: Local variable pbc is assigned to but never used

Remove assignment to unused variable pbc

(F841)

🤖 Prompt for AI Agents
In dptb/data/dataset/_default_dataset.py around line 417 the code assigns pbc =
info["pbc"] but never uses it; remove the unused variable assignment (and any
accompanying comment that questions its usage) so no unused local remains, or if
pbc is intended to be used, replace the removal by using it appropriately—most
likely simply delete the pbc = info["pbc"] line.

self.raw_data.append(
build_data(pos_typ=info["pos_type"],
root=os.path.join(self.root, file),
get_Hamiltonian=get_Hamiltonian,
get_overlap=get_overlap,
get_DM=get_DM,
get_eigenvalues=get_eigenvalues,
info=info))

# The AtomicData_options is never used here,
# because we always return a list of AtomicData objects in `get_data()`.
@@ -381,6 +441,7 @@ def get_data(self):
for subdata in tqdm(self.raw_data, desc="Loading data"):
# the type_mapper here is stored on the PyG `dataset` as the `transform` attribute,
# so the OrbitalMapper can be accessed via self.transform here
subdata: _TrajData
subdata_list = subdata.toAtomicDataList(self.transform)
all_data += subdata_list
return all_data
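A short, assumed usage of the loaded data (building on the hypothetical `dataset` from the earlier sketch; not part of this PR):

# get_data() flattens every subfolder's frames into a single list of AtomicData
data_list = dataset.get_data()
print(len(data_list))       # total number of frames across all info_files entries
first_frame = data_list[0]  # an AtomicData object, ready for batching/training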