diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py index ed84350c..c1eface3 100644 --- a/dptb/data/dataset/_default_dataset.py +++ b/dptb/data/dataset/_default_dataset.py @@ -167,66 +167,105 @@ def from_text_data(cls, @classmethod def from_ase_traj(cls, root: str, - get_Hamiltonian = False, - get_overlap = False, - get_DM = False, - get_eigenvalues = False, - info = None): - - assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData." - + get_Hamiltonian: bool = False, + get_overlap: bool = False, + get_DM: bool = False, + get_eigenvalues: bool = False, + info: Optional[Dict] = None): + ''' + Build the _TrajData instance by reading the data from the single data directory + that organized in the way compatible with the ASE. + + Parameters + ---------- + root: str + The folder where the data is stored, including the traj file. + get_Hamiltonian: bool + Whether to load the hamiltonian blocks. + get_overlap: bool + Whether to load the overlap blocks. + get_DM: bool + Whether to load the density matrix blocks. + get_eigenvalues: bool + Whether to load the eigenvalues. + info: dict + The description of the data, may be inconsistent with the real data, in this case, + the info will be updated. + ''' + assert not get_Hamiltonian * get_DM, \ + "Hamiltonian and Density Matrix can only loaded one at a time, " + \ + "for which will occupy the same attribute in the AtomicData." + + # read the ase trajectory file... traj_file = glob.glob(f"{root}/*.traj") - assert len(traj_file) == 1, print("only one ase trajectory file can be provided.") - traj = Trajectory(traj_file[0], 'r') - nframes = len(traj) - assert nframes > 0, print("trajectory file is empty.") - if nframes != info.get("nframes", None): - info['nframes'] = nframes - log.info(f"Number of frames ({nframes}) in trajectory file does not match the number of frames in info file.") - - natoms = traj[0].positions.shape[0] - if natoms != info["natoms"]: - info["natoms"] = natoms - - pbc = info.get("pbc",None) - if pbc is None: - pbc = traj[0].pbc.tolist() - info["pbc"] = pbc - - if isinstance(pbc, bool): - pbc = [pbc] * 3 - - if pbc != traj[0].pbc.tolist(): - log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, we use the one in info json. BE CAREFUL!") - - positions = [] - cell = [] - atomic_numbers = [] - - for atoms in traj: - positions.append(atoms.get_positions()) + assert len(traj_file) == 1, "only one ase trajectory file can be provided." + + atomic_numbers, positions, cell = None, None, None + # use the context manager to avoid memory leak + with Trajectory(traj_file[0], 'r') as traj: + # there are some dimensions of importance: nframe, natom, ... + nframes = len(traj) + assert nframes > 0, "trajectory file is empty." + # if there is discrepancy between nframes in info and traj, then update the info. + if nframes != info.get("nframes", None): + info['nframes'] = nframes + log.info(f"Number of frames ({nframes}) in trajectory file does not match the info file.") - atomic_numbers.append(atoms.get_atomic_numbers()) - if (np.abs(atoms.get_cell()-np.zeros([3,3]))< 1e-6).all(): - cell = None - else: - cell.append(atoms.get_cell()) - - positions = np.array(positions) - positions = positions.reshape(nframes,natoms, 3) - - if cell is not None: - cell = np.array(cell) - cell = cell.reshape(nframes,3, 3) - - atomic_numbers = np.array(atomic_numbers) - atomic_numbers = atomic_numbers.reshape(nframes, natoms) - - data = {} - if cell is not None: - data["cell"] = cell - data["pos"] = positions - data["atomic_numbers"] = atomic_numbers + # assuming there will not be number of atoms change within the trajectory file... + # we check, because the trajectory file does support this. + natoms = np.unique([len(atoms) for atoms in traj]) + assert len(natoms) == 1, "Number of atoms in trajectory file is not consistent." + natoms = natoms[0] + # natoms = traj[0].positions.shape[0] + if natoms != info["natoms"]: + info["natoms"] = natoms + log.info(f"Number of atoms ({natoms}) in trajectory file does not match the info file.") + + # handling the pbc flag + pbc = info.get("pbc", None) + if pbc is None: + # read from the trajectory...however, the same issue also exists here, the pbc may + # change along the trajectory, so we need to check it (only allow one pbc setting) + pbc = np.unique([atoms.pbc.tolist() for atoms in traj]) + assert len(pbc) == 1, "PBC setting in trajectory file is not consistent." + pbc = pbc[0] + assert isinstance(pbc, list) and len(pbc) == 3, \ + f"Unexpected `PBC` format: {pbc}" + info["pbc"] = pbc + # check on the value of pbc + if isinstance(pbc, bool): + pbc = [pbc] * 3 + if pbc != traj[0].pbc.tolist(): + log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, " + "we use the one in info json. BE CAREFUL!") + + # overwrite the following three to the empty lists + atomic_numbers, positions, cell = [], [], [] + for atoms in traj: + atoms: Atoms # type annotation :) + + atomic_numbers.append(atoms.get_atomic_numbers()) + positions.append(atoms.get_positions()) + # if there is no cell information, then set it to None. However, + # there is also the case that the invalidity of cell is reflected + # by the cell being all zeros. + cell_read = atoms.get_cell() + cell.append(None if np.allclose(cell_read, np.zeros((3, 3)), atol=1e-6) else cell_read) + + # the trajectory reading must be successful and not empty + assert positions + assert atomic_numbers + assert cell + + # this may raise errors about the inhomogenity of the data, or the reshape failed + data = {"pos": np.array(positions).reshape(nframes, natoms, 3), + "atomic_numbers": np.array(atomic_numbers).reshape(nframes, natoms)} + assert len(cell) == nframes + if all(c is not None for c in cell): + data["cell"] = np.array(cell).reshape(nframes, 3, 3) + else: + # otherwise, we expect that all cells are None, the hybrid case is not allowed + assert all(c is None for c in cell) return cls(root=root, data=data, @@ -318,18 +357,47 @@ def toAtomicDataList(self, idp: TypeMapper = None): class DefaultDataset(AtomicInMemoryDataset): - def __init__( - self, - root: str, - info_files: Dict[str, Dict], - url: Optional[str] = None, # seems useless but can't be remove - include_frames: Optional[List[int]] = None, # maybe support in future - type_mapper: TypeMapper = None, - get_Hamiltonian: bool = False, - get_overlap: bool = False, - get_DM: bool = False, - get_eigenvalues: bool = False, - ): + def __init__(self, + root: str, + info_files: Dict[str, Dict], + url: Optional[str] = None, # seems useless but can't be remove + include_frames: Optional[List[int]] = None, # maybe support in future + type_mapper: TypeMapper = None, + get_Hamiltonian: bool = False, + get_overlap: bool = False, + get_DM: bool = False, + get_eigenvalues: bool = False): + ''' + instantiate the default dataset. + + Parameters + ---------- + root : str + root directory of the dataset. + info_files : Dict[str, Dict] + the description of all the "valid" subfolders in the root directory, here the + "valid" means there are data files in the subfolder. + url : Optional[str], optional + not used in DeePTB. see its super class + include_frames : Optional[List[int]], optional + not used in DeePTB. see its super class + type_mapper: TypeMapper, optional + the mapping from orbpair index to reduced matrix element, see docstrings of class + OrbitalMapper for more information + get_Hamiltonian : bool, optional + whether to get the Hamiltonian, by default False + get_overlap : bool, optional + whether to get the overlap, by default False + get_DM : bool, optional + whether to get the density matrix, by default False + get_eigenvalues : bool, optional + whether to get the eigenvalues, by default False + ''' + def build_data(pos_typ: str, **kwargs): + builder = {'ase': _TrajData.from_ase_traj} + build_func = builder.get(pos_typ, _TrajData.from_text_data) + return build_func(**kwargs) + self.root = root self.url = url self.info_files = info_files @@ -345,24 +413,16 @@ def __init__( # get the info here info = info_files[file] # assert "AtomicData_options" in info - assert "r_max" in info - assert "pbc" in info - pbc = info["pbc"] - if info["pos_type"] == "ase": - subdata = _TrajData.from_ase_traj(os.path.join(self.root, file), - get_Hamiltonian, - get_overlap, - get_DM, - get_eigenvalues, - info=info) - else: - subdata = _TrajData.from_text_data(os.path.join(self.root, file), - get_Hamiltonian, - get_overlap, - get_DM, - get_eigenvalues, - info=info) - self.raw_data.append(subdata) + assert all(attr in info for attr in ["r_max", "pbc"]) + pbc = info["pbc"] # not used? + self.raw_data.append( + build_data(pos_typ=info["pos_type"], + root=os.path.join(self.root, file), + get_Hamiltonian=get_Hamiltonian, + get_overlap=get_overlap, + get_DM=get_DM, + get_eigenvalues=get_eigenvalues, + info=info)) # The AtomicData_options is never used here. # Because we always return a list of AtomicData object in `get_data()`. @@ -381,6 +441,7 @@ def get_data(self): for subdata in tqdm(self.raw_data, desc="Loading data"): # the type_mapper here is loaded in PyG `dataset` type as `transform` attritube # so the OrbitalMapper can be accessed by self.transform here + subdata: _TrajData subdata_list = subdata.toAtomicDataList(self.transform) all_data += subdata_list return all_data