refactor: reorganize imports and improve dataset builder logic in build.py #288
base: main
Conversation
📝 Walkthrough
Refactored dataset building logic in `dptb/data/build.py`.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30–40 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 1
Caution: Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
dptb/data/build.py (1)
**327-374**: Fix dict cutoff comparison bug and support int cutoffs in `check_cutoffs`

In `check_cutoffs`, the dict branches currently compare the entire dict to a scalar:

```python
if isinstance(cutoff_options['r_max'], dict):
    ...
    for key in cutoff_options['r_max']:
        ...
        assert self.r_max >= cutoff_options['r_max'][key], ...
```

This will raise

```
TypeError: '>=' not supported between instances of 'dict' and 'float'
```

as soon as a dict-valued cutoff is used, even if the dataset cutoffs are correct. The same pattern appears for `er_max` and `oer_max`. Also, the scalar branches only handle `float`, so `int` cutoffs from model options won’t be checked.

A safer implementation would compare per-key values and treat both `float` and `int` as numeric:

```diff
-        cutoff_options = collect_cutoffs(model.model_options)
-        if isinstance(cutoff_options['r_max'],dict):
-            assert isinstance(self.r_max,dict), "The r_max in model is a dict, but in dataset it is not."
-            for key in cutoff_options['r_max']:
-                if key not in self.r_max:
-                    log.error(f"The key {key} in r_max is not defined in dataset")
-                    raise ValueError(f"The key {key} in r_max is not defined in dataset")
-                assert self.r_max >= cutoff_options['r_max'][key], f"The r_max in model shoule be smaller than in dataset for {key}."
-
-        elif isinstance(cutoff_options['r_max'],float):
-            assert isinstance(self.r_max,float), "The r_max in model is a float, but in dataset it is not."
-            assert self.r_max >= cutoff_options['r_max'], "The r_max in model shoule be smaller than in dataset."
+        cutoff_options = collect_cutoffs(model.model_options)
+        if isinstance(cutoff_options['r_max'], dict):
+            assert isinstance(self.r_max, dict), "The r_max in model is a dict, but in dataset it is not."
+            for key, model_r in cutoff_options['r_max'].items():
+                if key not in self.r_max:
+                    log.error(f"The key {key} in r_max is not defined in dataset")
+                    raise ValueError(f"The key {key} in r_max is not defined in dataset")
+                assert self.r_max[key] >= model_r, (
+                    f"The r_max in model should be smaller than in dataset for {key}."
+                )
+        elif isinstance(cutoff_options['r_max'], (float, int)):
+            assert isinstance(self.r_max, (float, int)), "The r_max in model is scalar, but in dataset it is not."
+            assert self.r_max >= cutoff_options['r_max'], (
+                "The r_max in model should be smaller than in dataset."
+            )
@@
-        if isinstance(cutoff_options['er_max'],dict):
+        if isinstance(cutoff_options['er_max'], dict):
             assert isinstance(self.er_max,dict), "The er_max in model is a dict, but in dataset it is not."
-            for key in cutoff_options['er_max']:
-                if key not in self.er_max:
-                    log.error(f"The key {key} in er_max is not defined in dataset")
-                    raise ValueError(f"The key {key} in er_max is not defined in dataset")
-
-                assert self.er_max >= cutoff_options['er_max'][key], f"The er_max in model shoule be smaller than in dataset for {key}."
-
-        elif isinstance(cutoff_options['er_max'],float):
-            assert isinstance(self.er_max,float), "The er_max in model is a float, but in dataset it is not."
-            assert self.er_max >= cutoff_options['er_max'], "The er_max in model shoule be smaller than in dataset."
+            for key, model_er in cutoff_options['er_max'].items():
+                if key not in self.er_max:
+                    log.error(f"The key {key} in er_max is not defined in dataset")
+                    raise ValueError(f"The key {key} in er_max is not defined in dataset")
+                assert self.er_max[key] >= model_er, (
+                    f"The er_max in model should be smaller than in dataset for {key}."
+                )
+        elif isinstance(cutoff_options['er_max'], (float, int)):
+            assert isinstance(self.er_max, (float, int)), "The er_max in model is scalar, but in dataset it is not."
+            assert self.er_max >= cutoff_options['er_max'], (
+                "The er_max in model should be smaller than in dataset."
+            )
@@
-        if isinstance(cutoff_options['oer_max'],dict):
+        if isinstance(cutoff_options['oer_max'], dict):
             assert isinstance(self.oer_max,dict), "The oer_max in model is a dict, but in dataset it is not."
-            for key in cutoff_options['oer_max']:
-                if key not in self.oer_max:
-                    log.error(f"The key {key} in oer_max is not defined in dataset")
-                    raise ValueError(f"The key {key} in oer_max is not defined in dataset")
-
-                assert self.oer_max >= cutoff_options['oer_max'][key], f"The oer_max in model shoule be smaller than in dataset for {key}."
-
-        elif isinstance(cutoff_options['oer_max'],float):
-            assert isinstance(self.oer_max,float), "The oer_max in model is a float, but in dataset it is not."
-            assert self.oer_max >= cutoff_options['oer_max'], "The oer_max in model shoule be smaller than in dataset."
+            for key, model_oer in cutoff_options['oer_max'].items():
+                if key not in self.oer_max:
+                    log.error(f"The key {key} in oer_max is not defined in dataset")
+                    raise ValueError(f"The key {key} in oer_max is not defined in dataset")
+                assert self.oer_max[key] >= model_oer, (
+                    f"The oer_max in model should be smaller than in dataset for {key}."
+                )
+        elif isinstance(cutoff_options['oer_max'], (float, int)):
+            assert isinstance(self.oer_max, (float, int)), "The oer_max in model is scalar, but in dataset it is not."
+            assert self.oer_max >= cutoff_options['oer_max'], (
+                "The oer_max in model should be smaller than in dataset."
+            )
```

This function is currently unusable for dict cutoffs, so I’d consider this a blocker to relying on `check_cutoffs`.
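For reference, the failure mode reproduces in isolation (a minimal standalone sketch, independent of DeePTB):

```python
# The dict branch effectively compares a whole dict against a scalar:
r_max_dataset = {"A-B": 5.0}   # dict-valued cutoff, as in the dataset
r_max_model = 4.0              # scalar pulled from the model options

try:
    assert r_max_dataset >= r_max_model
except TypeError as exc:
    print(exc)  # '>=' not supported between instances of 'dict' and 'float'
```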
🧹 Nitpick comments (5)
dptb/data/build.py (5)
**30-117**: Clean up `dataset_from_config` docstring and import fallback handling

The implementation looks reasonable and fits the `dptb.data` organization, but there are a couple of polish items:

- The docstring still refers to `nequip` types and old test paths; updating it to reference `dptb` and the actual tests in this repo would avoid confusion.
- The broad `except Exception:` around the dynamic import can hide real bugs (e.g., typos in the module path) and silently fall back to scanning `dptb.data`. Narrowing this to `ImportError`/`AttributeError` and logging the underlying exception before falling back would make failures easier to debug.

Example patch:

```diff
-    """initialize database based on a config instance
+    """Initialize a dataset based on a config instance.
@@
-        config (dict, nequip.utils.Config): dict/object that store all the parameters
+        config (dict): mapping that stores all the parameters
@@
-        dataset (nequip.data.AtomicDataset)
+        dataset (dptb.data.AtomicDataset)
@@
-    else:
-        try:
-            module_name = ".".join(config_dataset.split(".")[:-1])
-            class_name = ".".join(config_dataset.split(".")[-1:])
-            class_name = getattr(import_module(module_name), class_name)
-        except Exception:
-            # ^ TODO: don't catch all Exception
+    else:
+        try:
+            module_name = ".".join(config_dataset.split(".")[:-1])
+            class_name = ".".join(config_dataset.split(".")[-1:])
+            class_name = getattr(import_module(module_name), class_name)
+        except (ImportError, AttributeError) as exc:
+            log.debug(
+                "Falling back to resolving dataset %s from dptb.data due to %r",
+                config_dataset,
+                exc,
+            )
+            # default class defined in dptb.data or dptb.data.dataset
```
**166-173**: Make cutoff warning less noisy and clarify `if_check_cutoffs` usage

Right now every `DatasetBuilder.__call__` unconditionally logs

> "The cutoffs in data and model are not checked. be careful!"

even if the caller never intends to run `check_cutoffs`, or has already validated cutoffs elsewhere. Also, `self.if_check_cutoffs` is set here but not used in this method.

Consider:

- Logging this warning only once per `DatasetBuilder` instance, or only when some `check_cutoffs`-related flag is false.
- Or, if `if_check_cutoffs` is meant to control this behavior, using it to gate the warning instead of resetting it on every call.

This will make logs cleaner when building many datasets in a loop.
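A minimal sketch of the "warn once per instance" option (attribute and method bodies here are illustrative assumptions, not the actual DeePTB code):

```python
import logging

log = logging.getLogger(__name__)

class DatasetBuilder:
    def __init__(self):
        # hypothetical flag: has the "cutoffs not checked" warning fired yet?
        self._warned_unchecked_cutoffs = False

    def __call__(self, *args, **kwargs):
        # emit the warning only on the first build from this instance
        if not self._warned_unchecked_cutoffs:
            log.warning("The cutoffs in data and model are not checked. Be careful!")
            self._warned_unchecked_cutoffs = True
        # ... existing dataset construction logic would follow here ...
```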
**175-194**: Avoid `assert` for user-facing validation in `__call__`

Several important validations use `assert`:

- dataset type membership (Lines 177–178)
- non-`None` prefix (Lines 184–185)
- presence of trajectory files in a folder (Lines 189–190)
- at least one valid folder found (Lines 192–193)

Since `assert` can be disabled with `python -O`, these checks may silently disappear, and invalid inputs could slip through (e.g., an unsupported `dataset_type` falling through to the LMDB branch).

It would be safer to replace these with explicit exceptions:

```diff
-        assert dataset_type in ["DefaultDataset", "DeePHDataset", "HDF5Dataset", "LMDBDataset"], \
-            f"The dataset type {dataset_type} is not supported. Please check the type."
+        if dataset_type not in ["DefaultDataset", "DeePHDataset", "HDF5Dataset", "LMDBDataset"]:
+            raise ValueError(f"The dataset type {dataset_type} is not supported. Please check the type.")
@@
-        assert prefix is not None, \
-            "The prefix is not provided. Please provide the prefix to select the trajectory folders."
+        if prefix is None:
+            raise ValueError(
+                "The prefix is not provided. Please provide the prefix to select the trajectory folders."
+            )
@@
-                assert any(folder.glob(f'*.{ext}') for ext in ['dat', 'traj', 'h5', 'mdb']), \
-                    f'{folder} does not have the proper traj data files. Please check the data files.'
+                if not any(folder.glob(f'*.{ext}') for ext in ['dat', 'traj', 'h5', 'mdb']):
+                    raise ValueError(
+                        f"{folder} does not have the proper traj data files. Please check the data files."
+                    )
@@
-        assert isinstance(valid_folders, list) and len(valid_folders) > 0, \
-            "No trajectory folders are found. Please check the prefix."
+        if not valid_folders:
+            raise ValueError("No trajectory folders are found. Please check the prefix.")
```
**195-236**: Double-check LMDB `info.json` behavior and simplify long error message

Two points in the `info.json` handling:

1. **LMDB semantics change:** For `dataset_type == "LMDBDataset"`, root-level `info.json` is explicitly ignored and per-folder entries are set to `{}` (Lines 206–220). Previously, LMDB flows typically used `normalize_lmdbsetinfo` to carry per-trajectory metadata (e.g., pbc, basis). If `LMDBDataset` still relies on that, this change could break existing LMDB configurations.

   Please verify LMDB-related tests and examples; if LMDB really no longer needs `info.json`, a short comment explaining that and possibly removing the unused `normalize_lmdbsetinfo` import would make the intent clearer. Otherwise, we should probably route LMDB info through `normalize_lmdbsetinfo` here.

2. **Long error message (Ruff TRY003):** The `ValueError` raised when neither per-folder nor public `info.json` exists (Lines 224–226) has a long inline message, which Ruff flags. You could keep the detailed explanation in the log and shorten the exception:

```diff
-                log.error(f"for {dataset_type} type, the info.json is not properly provided for `{folder}`")
-                raise ValueError(f"for {dataset_type} type, the info.json is not properly provided for `{folder}`")
+                msg = f"info.json is not properly provided for folder `{folder}` and dataset type {dataset_type}"
+                log.error(msg)
+                raise ValueError(msg)
```
**238-247**: Consider deepcopying cutoff dicts when attaching to `info_files`

The loop attaching cutoffs:

```python
for folder in info_files:
    info_files[folder].update({'r_max': r_max,
                               'er_max': er_max,
                               'oer_max': oer_max})
```

works fine for scalar cutoffs, but if any of `r_max`, `er_max`, or `oer_max` are dicts and later code mutates them per-trajectory, all folders will share the same dict instance.

If per-trajectory mutation is possible, consider deepcopying here:

```diff
 from copy import deepcopy
@@
 for folder in info_files:
-    info_files[folder].update({'r_max' : r_max,
-                               'er_max' : er_max,
-                               'oer_max': oer_max})
+    info_files[folder].update({
+        'r_max': deepcopy(r_max),
+        'er_max': deepcopy(er_max),
+        'oer_max': deepcopy(oer_max),
+    })
```

If you know these are treated as read-only, the current approach is fine.
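To make the aliasing concern concrete, a standalone sketch (folder and key names are illustrative only):

```python
shared = {"A-B": 5.0}                      # one dict-valued cutoff
info_files = {"traj1": {}, "traj2": {}}    # hypothetical per-folder info entries

for folder in info_files:
    info_files[folder].update({"r_max": shared})  # every folder gets the SAME dict

info_files["traj1"]["r_max"]["A-B"] = 3.0  # mutate through one folder
print(info_files["traj2"]["r_max"])        # {'A-B': 3.0}: all folders changed
```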
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
dptb/data/build.py (2 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: deepmodeling/DeePTB PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-26T21:32:19.567Z
Learning: Structure data loading and processing code in the `dptb/data` subpackage
Learnt from: CR
Repo: deepmodeling/DeePTB PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-26T21:32:19.567Z
Learning: Organize code into a main `dptb` package with subpackages for different functionalities (data, nn, negf, etc.)
🧬 Code graph analysis (1)
dptb/data/build.py (7)

- dptb/data/transforms.py (1): `OrbitalMapper` (395-889)
- dptb/data/dataset/lmdb_dataset.py (2): `len` (79-80), `LMDBDataset` (25-347)
- dptb/utils/tools.py (1): `j_loader` (381-407)
- dptb/utils/argcheck.py (1): `normalize_setinfo` (1561-1567)
- dptb/data/dataset/_deeph_dataset.py (1): `DeePHE3Dataset` (21-248)
- dptb/data/dataset/_default_dataset.py (1): `DefaultDataset` (319-569)
- dptb/data/dataset/_hdf5_dataset.py (1): `HDF5Dataset` (145-374)
🪛 Ruff (0.14.7)
dptb/data/build.py
226-226: Avoid specifying long messages outside the exception class (TRY003)
```python
# after preprocessing, we can build the dataset
if dataset_type == "DeePHDataset":
    return DeePHE3Dataset(root=root,
                          type_mapper=idp,
                          get_Hamiltonian=get_Hamiltonian,
                          get_eigenvalues=get_eigenvalues,
                          info_files=info_files)
elif dataset_type == "DefaultDataset":
    return DefaultDataset(root=root,
                          type_mapper=idp,
                          get_Hamiltonian=get_Hamiltonian,
                          get_overlap=get_overlap,
                          get_DM=get_DM,
                          get_eigenvalues=get_eigenvalues,
                          info_files=info_files)
elif dataset_type == "HDF5Dataset":
    return HDF5Dataset(root=root,
                       type_mapper=idp,
                       get_Hamiltonian=get_Hamiltonian,
                       get_overlap=get_overlap,
                       get_DM=get_DM,
                       get_eigenvalues=get_eigenvalues,
                       info_files=info_files)
else:
    assert dataset_type == "LMDBDataset"
    return LMDBDataset(root=root,
                       type_mapper=idp,
                       orthogonal=orthogonal,
                       get_Hamiltonian=get_Hamiltonian,
                       get_overlap=get_overlap,
                       get_DM=get_DM,
                       get_eigenvalues=get_eigenvalues,
                       info_files=info_files)
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, let's check the build.py file to find the __call__ method and from_model method
cd dptb/data && wc -l build.py
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Read the relevant sections of build.py to understand the method signatures
cat -n dptb/data/build.py | head -100
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Search for the __call__ method definition in build.py
rg -A 20 "def __call__" dptb/data/build.py
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Search for the from_model method definition
rg -A 20 "def from_model" dptb/data/build.py
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Look at the dataset class definitions to see if they accept **kwargs
fd "\.py$" dptb/data | xargs rg -l "class DeePHE3Dataset|class DefaultDataset|class HDF5Dataset|class LMDBDataset"
```

Repository: deepmodeling/DeePTB
Length of output: 182
🌐 Web query:
DeePTB dptb/data/build.py DatasetBuilder __call__ github
💡 Result:
You can find the `DatasetBuilder.__call__` implementation in dptb/data/build.py of the DeePTB repo on GitHub (it shows the full function signature and the logic for building DefaultDataset / DeePHDataset / HDF5Dataset / LMDBDataset). Source (raw file): [1]; repo page: [2].
Sources:
[1] DeePTB dptb/data/build.py (raw). https://github.com/deepmodeling/DeePTB/blob/main/dptb/data/build.py?raw=true.
[2] deepmodeling/DeePTB GitHub repository. https://github.com/deepmodeling/DeePTB.
🌐 Web query:
site:github.com/deepmodeling/DeePTB dptb/data/build.py __call__ kwargs
💡 Result:
The `DatasetBuilder.__call__` signature (from dptb/data/build.py) is:

```python
def __call__(self,
             root: str,
             r_max: Union[float,int,dict],
             er_max: Union[float,int,dict]=None,
             oer_max: Union[float,int,dict]=None,
             type: str = "DefaultDataset",
             prefix: str = None,
             separator: str = '.',
             get_Hamiltonian: bool = False,
             get_overlap: bool = False,
             get_DM: bool = False,
             get_eigenvalues: bool = False,
             orthogonal: bool = False,
             basis: str = None,
             **kwargs)
```
Notes:
- These named args (root, r_max, er_max, oer_max, type, prefix, separator, get_Hamiltonian, get_overlap, get_DM, get_eigenvalues, orthogonal, basis) are the documented/used parameters; **kwargs is accepted but not referenced elsewhere in the function (i.e., extra keyword args are accepted but ignored). [1]
Source: dptb/data/build.py (DeePTB repository). [1]
🏁 Script executed:

```shell
# Fetch the raw build.py file from GitHub
curl -s "https://github.com/deepmodeling/DeePTB/blob/main/dptb/data/build.py?raw=true" | head -300
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Get more of the file to find __call__ method
curl -s "https://github.com/deepmodeling/DeePTB/blob/main/dptb/data/build.py?raw=true" | tail -n +200 | head -150
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Search for dataset class definitions
curl -s "https://github.com/deepmodeling/DeePTB/blob/main/dptb/data/__init__.py?raw=true"
```

Repository: deepmodeling/DeePTB
Length of output: 182
🏁 Script executed:

```shell
# Look for the dataset classes - try common locations
for file in dptb/data/dataset.py dptb/data/datasets.py dptb/datasets.py; do
  echo "=== Checking $file ==="
  curl -s "https://raw.githubusercontent.com/deepmodeling/DeePTB/main/$file" | head -200
done
```

Repository: deepmodeling/DeePTB
Length of output: 182
🌐 Web query:
site:github.com/deepmodeling/DeePTB "class DeePHE3Dataset" OR "class DefaultDataset" OR "class HDF5Dataset" OR "class LMDBDataset" __init__
💡 Result:
I searched the deepmodeling/DeePTB repository for those class names.
Findings:
- "DefaultDataset" is referenced in the DeePTB docs (data_options example). [1][2]
- I did not find occurrences of a class named "DeePHE3Dataset", nor explicit class definitions named "HDF5Dataset" or "LMDBDataset" in the repository search results I could access.
Sources:
- deepmodeling/DeePTB repository main page. [1]
- DeePTB docs (quick start input / data_options showing DefaultDataset). [2]
If you want, I can (a) do a deeper code search through the repo files (open specific subfolders) and return exact file/line hits, or (b) provide direct links to any matching source files. Which would you prefer?
🌐 Web query:
deepmodeling/DeePTB dataset classes __init__ signature DefaultDataset HDF5Dataset LMDBDataset
💡 Result:
I can fetch those exact `__init__` signatures from the DeePTB repo and return them. Do you want me to look them up on GitHub now?
__call__ accepts **kwargs but silently ignores them instead of forwarding to dataset constructors
The __call__ method signature includes **kwargs, but these are never passed to any of the four dataset constructors (DeePHE3Dataset, DefaultDataset, HDF5Dataset, LMDBDataset). This means callers cannot configure dataset-specific options via build_dataset(..., some_option=...) — such arguments are silently dropped, which is surprising and error-prone.
Either forward **kwargs to the dataset constructors, or remove the parameter from __call__ and from_model to be explicit that extra options are not supported.
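For illustration, a minimal, self-contained sketch of the forwarding option; the class below is a stand-in, not the real `dptb` dataset class, and actual constructor support for extra kwargs would need to be verified before adopting this:

```python
# Stand-in for a real dataset class; assumes it tolerates extra kwargs.
class DefaultDataset:
    def __init__(self, root, info_files, **extra):
        self.root = root
        self.info_files = info_files
        self.extra = extra  # dataset-specific options land here

def build_dataset(root, info_files, dataset_type="DefaultDataset", **kwargs):
    # Forward caller-supplied extras instead of silently dropping them.
    if dataset_type == "DefaultDataset":
        return DefaultDataset(root=root, info_files=info_files, **kwargs)
    raise ValueError(f"Unsupported dataset type: {dataset_type}")

ds = build_dataset("./data", {}, some_option=True)
assert ds.extra == {"some_option": True}
```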
@coderabbitai yes I noticed this, too. But this is not the problem brought about by this PR, right?
@kirk0830 You're absolutely right! This is a pre-existing issue with the **kwargs signature, not something introduced by your refactoring. Your PR is focused on improving the folder discovery logic, Path usage, and info.json handling – all good changes.
The **kwargs was already there and remains unused in the original code as well. If you'd like to address it in a future PR (either by forwarding the kwargs or removing the parameter), I can help open a tracking issue for it. Otherwise, feel free to keep this PR focused on the refactoring work you've done.
Summary by CodeRabbit
New Features
Improvements