Skip to content

Commit 48462af

Browse files
authored
Merge pull request #370 from anyangml2nd/feat/support-vacancy-task
Feat: add vacancy formation task
2 parents dc3889c + ae72bfe commit 48462af

5 files changed

Lines changed: 92 additions & 0 deletions

File tree

lambench/metrics/downstream_tasks_metrics.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ elastic:
2020
metrics: [MAE_G_VRH, MAE_K_VRH]
2121
penalty: success_rate
2222
dummy: {"MAE_G_VRH": 67.5431, "MAE_K_VRH": 136.2597}
23+
vacancy:
24+
domain: Inorganic Materials
25+
metrics: [MAE]
26+
dummy: {"MAE": 4.381}

lambench/metrics/post_process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def process_domain_specific_for_one_model(model: BaseLargeAtomModel):
117117
"neb",
118118
"wiggle150",
119119
"elastic",
120+
"vacancy",
120121
]:
121122
applicability_results[record.task_name] = record.metrics
122123
return applicability_results

lambench/models/ase_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ def evaluate(
258258
fmax = task.calculator_params.get("fmax", 1e-3)
259259
max_steps = task.calculator_params.get("max_steps", 500)
260260
return {"metrics": run_inference(self, task.test_data, fmax, max_steps)}
261+
elif task.task_name == "vacancy":
262+
from lambench.tasks.calculator.vacancy.vacancy import run_inference
263+
264+
assert task.test_data is not None
265+
return {"metrics": run_inference(self, task.test_data)}
261266
else:
262267
raise NotImplementedError(f"Task {task.task_name} is not implemented.")
263268

lambench/tasks/calculator/calculator_tasks.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ elastic:
2828
calculator_params:
2929
fmax: 0.001
3030
max_steps: 500
31+
vacancy:
32+
test_data: /bohr/lambench-vacancy-a2xo/v1
33+
calculator_params: null
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
The test data is retrieved from:
3+
Chem. Mater. 2023, 35, 24, 10619–10634
4+
5+
https://pubs.acs.org/doi/10.1021/acs.chemmater.3c02251
6+
7+
Only 1813 structure pairs are used.
8+
9+
"""
10+
11+
from ase.io import read
12+
import numpy as np
13+
from ase import Atoms
14+
from tqdm import tqdm
15+
from pathlib import Path
16+
17+
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
18+
19+
from lambench.models.ase_models import ASEModel
20+
import logging
21+
22+
23+
def get_oxygen_reference_energy(calc) -> float:
24+
vacuum_size = 30 # Ångströms: Large cell size to ensure vacuum separation
25+
o_o_bond_length = 1.23 # Ångströms: Experimental O-O bond length for O2
26+
cell_vector = vacuum_size
27+
cell = [cell_vector, cell_vector, cell_vector]
28+
center = cell_vector / 2
29+
30+
positions = [
31+
(center, center, center - o_o_bond_length / 2),
32+
(center, center, center + o_o_bond_length / 2),
33+
]
34+
35+
molecular_oxygen = Atoms("O2", positions=positions, cell=cell, pbc=True)
36+
molecular_oxygen.calc = calc
37+
return molecular_oxygen.get_potential_energy() / 2
38+
39+
40+
def run_inference(
41+
model: ASEModel,
42+
test_data: Path,
43+
) -> dict[str, float]:
44+
pristine_structures = read(test_data / "vacancy_pristine_structures.traj", ":")
45+
defect_structures = read(test_data / "vacancy_defect_structures.traj", ":")
46+
labels = np.load(test_data / "vacancy_evf_label.npy")
47+
48+
evf_lab = []
49+
evf_pred = []
50+
calc = model.calc
51+
52+
# Calculate reference energy for oxygen atom
53+
E_o = get_oxygen_reference_energy(calc)
54+
55+
for pristine, defect, label in tqdm(
56+
zip(pristine_structures, defect_structures, labels)
57+
):
58+
natoms_pri = len(pristine)
59+
natoms_def = len(defect)
60+
61+
n_oxygen = natoms_pri - natoms_def
62+
63+
pristine.calc = calc
64+
defect.calc = calc
65+
try:
66+
final = defect.get_potential_energy()
67+
initial = pristine.get_potential_energy()
68+
69+
e_vf = final + n_oxygen * E_o - initial
70+
evf_lab.append(label)
71+
evf_pred.append(e_vf)
72+
73+
except Exception as e:
74+
logging.error(f"Error occurred while processing structures: {e}")
75+
76+
return {
77+
"MAE": mean_absolute_error(evf_lab, evf_pred), # eV
78+
"RMSE": root_mean_squared_error(evf_lab, evf_pred), # eV
79+
}

0 commit comments

Comments
 (0)