Skip to content

Commit

Permalink
tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Oct 5, 2024
1 parent 0f7af72 commit b289c02
Show file tree
Hide file tree
Showing 34 changed files with 486 additions and 328 deletions.
3 changes: 3 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ Breaking Changes
for CLI. @loriab
* as promised, `local_options` has been removed in favor of `task_config`.
* compute and compute_procedure have been merged in favor of the former.
* `compute` learned an optional argument `return_version` to specify the schema_version of the
  returned model or dictionary. By default it returns the input schema_version; if that is not
  determinable, it returns v1. @loriab

New Features
++++++++++++
Expand Down
25 changes: 14 additions & 11 deletions qcengine/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ def compute(
program: str,
raise_error: bool = False,
task_config: Optional[Dict[str, Any]] = None,
local_options: Optional[Dict[str, Any]] = None,
return_dict: bool = False,
schema_version: int = -1,
return_version: int = -1,
) -> Union[
"BaseModel", "FailedOperation", Dict[str, Any]
]: # TODO Output base class, was AtomicResult OptimizationResult
Expand All @@ -53,18 +52,16 @@ def compute(
input_data
A QCSchema input specification in dictionary or model from QCElemental.models
program
The CMS program or procedure with which to execute the input.
The CMS program or procedure with which to execute the input. E.g., "psi4", "rdkit", "geometric".
raise_error
Determines if compute should raise an error or not.
retries : int, optional
The number of random tries to retry for.
task_config
A dictionary of local configuration options corresponding to a TaskConfig object.
local_options
Deprecated parameter, renamed to ``task_config``
return_dict
Returns a dict instead of qcelemental.models.AtomicResult
schema_version
Returns a dict instead of qcelemental.models.AtomicResult # TODO base Result class
return_version
The schema version to return. If -1, the input schema_version is used.
Returns
Expand All @@ -75,20 +72,24 @@ def compute(
"""

try:
# models, v1 or v2
output_data = input_data.model_copy()
except AttributeError:
# dicts
output_data = input_data.copy() # lgtm [py/multiple-definition]

with compute_wrapper(capture_output=False, raise_error=raise_error) as metadata:

# Grab the executor and build the input model
# Grab the executor harness
try:
executor = get_procedure(program)
except InputError:
executor = get_program(program)

# Build the model and validate
input_data = executor.build_input_model(input_data) # calls model_wrapper
# * calls model_wrapper with the (Atomic|Optimization|etc)Input for which the harness was designed
# * upon return, input_data is a model of the type (e.g., Atomic) and version (e.g., 1 or 2) the harness prefers. for now, v1.
input_data, input_schema_version = executor.build_input_model(input_data, return_input_schema_version=True)
convert_version = input_schema_version if return_version == -1 else return_version

# Build out task_config
if task_config is None:
Expand All @@ -114,7 +115,9 @@ def compute(
except:
raise

return handle_output_metadata(output_data, metadata, raise_error=raise_error, return_dict=return_dict)
return handle_output_metadata(
output_data, metadata, raise_error=raise_error, return_dict=return_dict, convert_version=convert_version
)


def compute_procedure(*args, **kwargs):
Expand Down
6 changes: 4 additions & 2 deletions qcengine/procedures/berny.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def found(self, raise_error: bool = False) -> bool:
raise_msg="Please install via `pip install pyberny`.",
)

def build_input_model(self, data: Union[Dict[str, Any], "OptimizationInput"]) -> "OptimizationInput":
return self._build_model(data, OptimizationInput)
def build_input_model(
    self, data: Union[Dict[str, Any], "OptimizationInput"], *, return_input_schema_version: bool = False
) -> "OptimizationInput":
    """Validate user input into an ``OptimizationInput`` model.

    Delegates to the shared ``_build_model`` helper, forwarding
    ``return_input_schema_version`` so callers can also recover the schema
    version detected on the user-supplied input.
    """
    target_model = "OptimizationInput"
    return self._build_model(data, target_model, return_input_schema_version=return_input_schema_version)

def compute(
self, input_data: "OptimizationInput", config: "TaskConfig"
Expand Down
6 changes: 4 additions & 2 deletions qcengine/procedures/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def get_version(self) -> str:

return self.version_cache[which_prog]

def build_input_model(self, data: Union[Dict[str, Any], "OptimizationInput"]) -> "OptimizationInput":
return self._build_model(data, OptimizationInput)
def build_input_model(
    self, data: Union[Dict[str, Any], "OptimizationInput"], *, return_input_schema_version: bool = False
) -> "OptimizationInput":
    """Validate user input into an ``OptimizationInput`` model.

    Thin wrapper over the shared ``_build_model`` helper; the
    ``return_input_schema_version`` flag is passed straight through.
    """
    target_model = "OptimizationInput"
    return self._build_model(data, target_model, return_input_schema_version=return_input_schema_version)

def compute(self, input_model: "OptimizationInput", config: "TaskConfig") -> "OptimizationResult":
try:
Expand Down
28 changes: 25 additions & 3 deletions qcengine/procedures/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
from typing import Any, Dict, Union
from typing import Any, Dict, Tuple, Union

import qcelemental
from pydantic import BaseModel, ConfigDict

from ..util import model_wrapper
Expand Down Expand Up @@ -52,12 +53,33 @@ def found(self, raise_error: bool = False) -> bool:
If the procedure was found or not.
"""

def _build_model(self, data: Dict[str, Any], model: "BaseModel") -> "BaseModel":
def _build_model(
    self, data: Dict[str, Any], model: "BaseModel", /, *, return_input_schema_version: bool = False
) -> Union["BaseModel", Tuple["BaseModel", int]]:
    """
    Quick wrapper around util.model_wrapper for inherited classes.

    Parameters
    ----------
    data
        User input: a v1 model instance, a v2 model instance, or a plain dictionary.
    model
        Name of the model class (e.g., ``"OptimizationInput"``) to resolve in
        ``qcelemental.models.v1`` / ``qcelemental.models.v2``.
    return_input_schema_version
        If True, also return the schema version detected on the input.

    Returns
    -------
    The validated model converted to v1 (the version harnesses currently consume),
    optionally paired with the input's schema version.

    Raises
    ------
    TypeError
        If ``data`` is neither a recognized model instance nor a dictionary.
    """
    v1_model = getattr(qcelemental.models.v1, model)
    v2_model = getattr(qcelemental.models.v2, model)

    if isinstance(data, v1_model):
        mdl = model_wrapper(data, v1_model)
    elif isinstance(data, v2_model):
        mdl = model_wrapper(data, v2_model)
    elif isinstance(data, dict):
        # remember these are user-provided dictionaries, so they'll have the mandatory fields,
        # like driver, not the helpful discriminator fields like schema_version.

        # for now, the two dictionaries look the same, so cast to the one we want
        # note that this prevents correctly identifying the user schema version when dict passed in, so either as_v1/None or as_v2 will fail
        mdl = model_wrapper(data, v1_model)  # TODO v2
    else:
        # Previously, an unrecognized input type fell through all branches and
        # surfaced as an UnboundLocalError on `mdl`; raise a clear error instead.
        raise TypeError(f"Input type {type(data)} not recognized: expected a {model} model or a dict.")

    input_schema_version = mdl.schema_version
    if return_input_schema_version:
        return mdl.convert_v(1), input_schema_version  # non-psi4 return_dict=False fail w/o this
    else:
        return mdl.convert_v(1)

def get_version(self) -> str:
"""Finds procedure, extracts version, returns normalized version string.
Expand Down
6 changes: 4 additions & 2 deletions qcengine/procedures/nwchem_opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def get_version(self) -> str:
nwc_harness = NWChemHarness()
return nwc_harness.get_version()

def build_input_model(self, data: Union[Dict[str, Any], "OptimizationInput"]) -> OptimizationInput:
return self._build_model(data, OptimizationInput)
def build_input_model(
    self, data: Union[Dict[str, Any], "OptimizationInput"], *, return_input_schema_version: bool = False
) -> "OptimizationInput":
    """Validate user input into an ``OptimizationInput`` model.

    All of the work happens in the shared ``_build_model`` helper; this
    wrapper only names the target model and forwards the version flag.
    """
    target_model = "OptimizationInput"
    return self._build_model(data, target_model, return_input_schema_version=return_input_schema_version)

def compute(self, input_data: OptimizationInput, config: TaskConfig) -> "BaseModel":
nwc_harness = NWChemHarness()
Expand Down
6 changes: 4 additions & 2 deletions qcengine/procedures/optking.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def found(self, raise_error: bool = False) -> bool:
raise_msg="Please install via `conda install optking -c conda-forge`.",
)

def build_input_model(self, data: Union[Dict[str, Any], "OptimizationInput"]) -> "OptimizationInput":
return self._build_model(data, OptimizationInput)
def build_input_model(
    self, data: Union[Dict[str, Any], "OptimizationInput"], *, return_input_schema_version: bool = False
) -> "OptimizationInput":
    """Validate user input into an ``OptimizationInput`` model.

    Forwards to the shared ``_build_model`` helper, including the
    ``return_input_schema_version`` flag.
    """
    target_model = "OptimizationInput"
    return self._build_model(data, target_model, return_input_schema_version=return_input_schema_version)

def get_version(self) -> str:
self.found(raise_error=True)
Expand Down
6 changes: 4 additions & 2 deletions qcengine/procedures/torsiondrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def found(self, raise_error: bool = False) -> bool:
raise_msg="Please install via `conda install torsiondrive -c conda-forge`.",
)

def build_input_model(self, data: Union[Dict[str, Any], "TorsionDriveInput"]) -> "TorsionDriveInput":
return self._build_model(data, TorsionDriveInput)
def build_input_model(
    self, data: Union[Dict[str, Any], "TorsionDriveInput"], *, return_input_schema_version: bool = False
) -> "TorsionDriveInput":
    """Validate user input into a ``TorsionDriveInput`` model.

    Forwards to the shared ``_build_model`` helper, including the
    ``return_input_schema_version`` flag.
    """
    target_model = "TorsionDriveInput"
    return self._build_model(data, target_model, return_input_schema_version=return_input_schema_version)

def _compute(self, input_model: "TorsionDriveInput", config: "TaskConfig"):

Expand Down
44 changes: 34 additions & 10 deletions qcengine/programs/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import logging
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union

import qcelemental
from pydantic import BaseModel, ConfigDict
from qcelemental.models import AtomicInput, AtomicResult, FailedOperation

from qcengine.config import TaskConfig
from qcengine.exceptions import KnownErrorException
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self, **kwargs):
super().__init__(**{**self._defaults, **kwargs})

@abc.abstractmethod
def compute(self, input_data: AtomicInput, config: TaskConfig) -> Union[AtomicResult, FailedOperation]:
def compute(self, input_data: "AtomicInput", config: TaskConfig) -> Union["AtomicResult", "FailedOperation"]:
"""Top-level compute method to be implemented for every ProgramHarness
Note:
Expand Down Expand Up @@ -68,13 +68,34 @@ def found(raise_error: bool = False) -> bool:

# def _build_model

def build_input_model(self, data: Dict[str, Any]) -> "AtomicInput":
def build_input_model(
    self, data: Dict[str, Any], *, return_input_schema_version: bool = False
) -> Union["AtomicInput", Tuple["AtomicInput", int]]:
    """
    Quick wrapper around util.model_wrapper for inherited classes.

    Parameters
    ----------
    data
        User input: a v1 ``AtomicInput``, a v2 ``AtomicInput``, or a plain dictionary.
    return_input_schema_version
        If True, also return the schema version detected on the input.

    Returns
    -------
    The validated ``AtomicInput`` converted to v1 (the version harnesses currently
    consume), optionally paired with the input's schema version.

    Raises
    ------
    TypeError
        If ``data`` is neither a recognized model instance nor a dictionary.
    """
    # Note: Someday when the multiple QCSchema versions QCEngine supports are all within the
    # Pydantic v2 API base class, this can use discriminated unions instead of logic.

    if isinstance(data, qcelemental.models.v1.AtomicInput):
        mdl = model_wrapper(data, qcelemental.models.v1.AtomicInput)
    elif isinstance(data, qcelemental.models.v2.AtomicInput):
        mdl = model_wrapper(data, qcelemental.models.v2.AtomicInput)
    elif isinstance(data, dict):
        # remember these are user-provided dictionaries, so they'll have the mandatory fields,
        # like driver, not the helpful discriminator fields like schema_version.

        # for now, the two dictionaries look the same, so cast to the one we want
        # note that this prevents correctly identifying the user schema version when dict passed in, so either as_v1/None or as_v2 will fail
        mdl = model_wrapper(
            data, qcelemental.models.v1.AtomicInput
        )  # TODO v2 # TODO kill off excuse_as_v2, now fix 2->-1 in schema_versions
    else:
        # Previously, an unrecognized input type fell through all branches and
        # surfaced as an UnboundLocalError on `mdl`; raise a clear error instead.
        raise TypeError(f"Input type {type(data)} not recognized: expected an AtomicInput model or a dict.")

    input_schema_version = mdl.schema_version
    if return_input_schema_version:
        return mdl.convert_v(1), input_schema_version  # non-psi4 return_dict=False fail w/o this
    else:
        return mdl.convert_v(1)

def get_version(self) -> str:
"""Finds program, extracts version, returns normalized version string.
Expand All @@ -88,7 +109,7 @@ def get_version(self) -> str:
## Computers

def build_input(
self, input_model: AtomicInput, config: TaskConfig, template: Optional[str] = None
self, input_model: "AtomicInput", config: TaskConfig, template: Optional[str] = None
) -> Dict[str, Any]:
raise ValueError("build_input is not implemented for {}.", self.__class__)

Expand Down Expand Up @@ -124,10 +145,10 @@ class ErrorCorrectionProgramHarness(ProgramHarness, abc.ABC):
``ErrorCorrectionProgramHarness`` and used to determine if/how to re-run the computation.
"""

def _compute(self, input_data: AtomicInput, config: TaskConfig) -> AtomicResult:
def _compute(self, input_data: "AtomicInput", config: TaskConfig) -> "AtomicResult":
raise NotImplementedError()

def compute(self, input_data: AtomicInput, config: TaskConfig) -> AtomicResult:
def compute(self, input_data: "AtomicInput", config: TaskConfig) -> "AtomicResult":
# Get the error correction configuration
error_policy = input_data.protocols.error_correction

Expand Down Expand Up @@ -163,7 +184,10 @@ def compute(self, input_data: AtomicInput, config: TaskConfig) -> AtomicResult:
keyword_updates = e.create_keyword_update(local_input_data)
new_keywords = local_input_data.keywords.copy()
new_keywords.update(keyword_updates)
local_input_data = AtomicInput(**local_input_data.dict(exclude={"keywords"}), keywords=new_keywords)
# TODO v2
local_input_data = qcelemental.models.v1.AtomicInput(
**local_input_data.dict(exclude={"keywords"}), keywords=new_keywords
)

# Store the error details and mitigations employed
observed_errors[e.error_name] = {"details": e.details, "keyword_updates": keyword_updates}
Expand Down
8 changes: 5 additions & 3 deletions qcengine/programs/tests/standard_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
pp = pprint.PrettyPrinter(width=120)


def runner_asserter(inp, ref_subject, method, basis, tnm, scramble, frame, models):
def runner_asserter(inp, ref_subject, method, basis, tnm, scramble, frame, models, retver):

qcprog = inp["call"]
qc_module_in = inp["qc_module"] # returns "<qcprog>"|"<qcprog>-<module>" # input-specified routing
Expand Down Expand Up @@ -143,14 +143,16 @@ def runner_asserter(inp, ref_subject, method, basis, tnm, scramble, frame, model
errtype, errmatch, reason = inp["error"]
with pytest.raises(errtype) as e:
atin = checkver_and_convert(atin, tnm, "pre")
qcng.compute(atin, qcprog, raise_error=True, return_dict=True, task_config=local_options)
qcng.compute(
atin, qcprog, raise_error=True, return_dict=True, task_config=local_options, return_version=retver
)

assert re.search(errmatch, str(e.value)), f"Not found: {errtype} '{errmatch}' in {e.value}"
# _recorder(qcprog, qc_module_in, driver, method, reference, fcae, scf_type, corl_type, "error", "nyi: " + reason)
return

atin = checkver_and_convert(atin, tnm, "pre")
wfn = qcng.compute(atin, qcprog, raise_error=True, task_config=local_options)
wfn = qcng.compute(atin, qcprog, raise_error=True, task_config=local_options, return_version=retver)
wfn = checkver_and_convert(wfn, tnm, "post")

print("WFN")
Expand Down
8 changes: 6 additions & 2 deletions qcengine/programs/tests/test_adcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ def h2o_data():

@using("adcc")
def test_run(h2o_data, schema_versions, request):
models, _ = schema_versions
models, retver, _ = schema_versions
h2o = models.Molecule.from_data(h2o_data)

inp = models.AtomicInput(
molecule=h2o, driver="properties", model={"method": "adc2", "basis": "sto-3g"}, keywords={"n_singlets": 3}
)

inp = checkver_and_convert(inp, request.node.name, "pre")
ret = qcng.compute(inp, "adcc", raise_error=True, task_config={"ncores": 1}, return_dict=True)
ret = qcng.compute(
inp, "adcc", raise_error=True, task_config={"ncores": 1}, return_dict=True, return_version=retver
)
ret = checkver_and_convert(ret, request.node.name, "post")
# note dict-out

ref_excitations = np.array([0.0693704245883876, 0.09773854881340478, 0.21481589246935925])
ref_hf_energy = -74.45975898670224
Expand Down
6 changes: 4 additions & 2 deletions qcengine/programs/tests/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.fixture
def clsd_open_pmols(schema_versions):
models, _ = schema_versions
models, _, _ = schema_versions
frame_not_important = {
name[:-4]: models.Molecule.from_data(smol, name=name[:-4])
for name, smol in std_molecules.items()
Expand Down Expand Up @@ -94,6 +94,7 @@ def clsd_open_pmols(schema_versions):
],
)
def test_hf_alignment(inp, scramble, frame, driver, basis, subjects, clsd_open_pmols, request, schema_versions):
models, retver, _ = schema_versions
runner_asserter(
*_processor(
inp,
Expand All @@ -102,7 +103,8 @@ def test_hf_alignment(inp, scramble, frame, driver, basis, subjects, clsd_open_p
subjects,
clsd_open_pmols,
request,
schema_versions[0],
models,
retver,
driver,
"hf",
scramble=scramble,
Expand Down
Loading

0 comments on commit b289c02

Please sign in to comment.