Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed .gitattributes
Empty file.
20 changes: 0 additions & 20 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,6 @@ jobs:
- name: Run formatter and linter
run: pixi run fmt

test:
name: Perform tests
runs-on: ubuntu-latest
timeout-minutes: 15
strategy:
matrix:
environment: [test-py311, test-py312]
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Pixi
uses: prefix-dev/setup-pixi@v0.8.1
with:
pixi-version: v0.40.2
cache: false
environments: ${{ matrix.environment }}
activate-environment: true
- name: Run tests
run: pixi run test

docs:
name: Generate documentation
runs-on: ubuntu-latest
Expand Down
70 changes: 21 additions & 49 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,29 @@ description = "AiiDA plugin for FANS, an FFT-based homogenization solver."
urls = {Documentation = "http://aiida-fans.readthedocs.io/en/latest/", Source = "https://github.com/ethan-shanahan/aiida-fans" }
authors = [{name = "Ethan Shanahan", email = "ethan.shanahan@gmail.com"}]
readme = "README.md"
license = {file = "LICENSE"}
license = "LGPL-3.0-or-later"
license-files = ["LICENSE"]
classifiers = [
"Natural Language :: English",
"Programming Language :: Python",
"Operating System :: POSIX :: Linux",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
"Development Status :: 3 - Alpha",
"Framework :: AiiDA"
]
keywords = ["aiida", "plugin", "fans"]
keywords = ["aiida", "fans", "plugin"]
requires-python = ">=3.11"
dependencies = [
"aiida-core>=2.6",
"h5py"
]

# Entry Points
# [project.entry-points."aiida.data"]
# "fans" = "aiida_fans.data:FANSParameters"
[project.entry-points."aiida.calculations"]
"fans.stashed" = "aiida_fans.calculations:FansStashedCalculation"
"fans.fragmented" = "aiida_fans.calculations:FansFragmentedCalculation"
[project.entry-points."aiida.parsers"]
"fans" = "aiida_fans.parsers:FansParser"
# [project.entry-points."aiida.cmdline.data"]
# "fans" = "aiida_fans.cli:data_cli"

# Build System
[build-system]
Expand All @@ -46,63 +42,44 @@ build-backend = "setuptools.build_meta"
channels = ["conda-forge"]
platforms = ["linux-64"]

### pixi: default dependencies (in addition to aiida-core)
[tool.pixi.dependencies]
# None
[tool.pixi.pypi-dependencies]
# None

### pixi: default tasks
[tool.pixi.tasks]
# None

### pixi: features
[tool.pixi.feature.self]
pypi-dependencies = {aiida-fans = { path = ".", editable = true }}
[tool.pixi.feature.plugin]
dependencies = {aiida-fans = "==0.1.5"}
# [tool.pixi.feature.aiida]
# dependencies = {aiida-core = "2.6.*"}
[tool.pixi.feature.fans]
dependencies = {fans = "0.4.*"}
[tool.pixi.feature.py3]
dependencies = {python = "3.*"}
[tool.pixi.feature.ruff]
dependencies = {ruff = "*"}
tasks = {fmt = "ruff check", dummy = "echo dummy", my-dummy="echo my-dummy"}
tasks = {fmt = "ruff check"}
[tool.pixi.feature.build]
pypi-dependencies = {build = "*"}
tasks = {build-dist = "python -m build"}
[tool.pixi.feature.sphinx]
dependencies = {sphinx = "*", sphinx-book-theme = "*"}
tasks = {build-docs = "sphinx-build -M html docs/source docs/build"}
[tool.pixi.feature.pytest]
dependencies = {pytest = "*"}
tasks = {test = "echo dummy test passes"}
[tool.pixi.feature.marimo]
dependencies = {marimo = "0.13.*"}
dependencies = {marimo = "0.14.*"}
tasks = {tutorial = "marimo edit tutorial.py"}
[tool.pixi.feature.py311]
dependencies = {python = "3.11.*"}
[tool.pixi.feature.py312]
dependencies = {python = "3.12.*"}
# [tool.pixi.feature.py313]
# dependencies = {python = "3.13.*"}
[tool.pixi.feature.fans]
dependencies = {fans = "0.4.*"}

### pixi: default environment dependencies
[tool.pixi.dependencies]
aiida-core = "2.6.*"
h5py = "*"
[tool.pixi.pypi-dependencies]
aiida-fans = { path = ".", editable = true }

### pixi: environments
[tool.pixi.environments]
dev = { features = ["self", "ruff", "pytest"], solve-group = "default" }
fmt = { features = ["ruff", "py312"], no-default-feature = true }
dist = { features = ["build", "py312"], no-default-feature = true }
docs = { features = ["sphinx", "py312"], no-default-feature = true }
test-py311 = { features = ["self", "fans", "pytest", "py311"], solve-group = "py311" }
test-py312 = { features = ["self", "fans", "pytest", "py312"], solve-group = "py312" }
# test-py313 = { features = ["self", "fans", "pytest", "py313"], solve-group = "py313" }
tutorial = { features = ["plugin", "fans", "marimo"], no-default-feature = true}
fmt = { no-default-feature = true, features = ["py3", "ruff"] } # CI env
dist = { no-default-feature = true, features = ["py3", "build"] } # CI env
docs = { no-default-feature = true, features = ["py3", "sphinx"] } # CI env
tutorial = { features = ["marimo", "fans"] }


## Build Tools: setuptools_scm
[tool.setuptools_scm]
version_file = "src/aiida_fans/_version.py"


## Style Tools: ruff
[tool.ruff]
extend-exclude = [
Expand Down Expand Up @@ -131,10 +108,5 @@ select = [
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # Allow unused variables when underscore-prefixed.
pydocstyle = {convention = "google"}

## Test Tools: pytest
[tool.pytest.ini_options]
[tool.coverage]
source = ["src/aiida_fans"]

## Docs Tools: sphinx
# None
60 changes: 33 additions & 27 deletions src/aiida_fans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,22 @@ def aiida_type(value: Any) -> type[Data]:
"""
match value:
case str():
return DataFactory("core.str") # Str
return DataFactory("core.str") # Str
case int():
return DataFactory("core.int") # Int
return DataFactory("core.int") # Int
case float():
return DataFactory("core.float") # Float
return DataFactory("core.float") # Float
case list():
return DataFactory("core.list") # List
return DataFactory("core.list") # List
case dict():
if all(map(lambda t: isinstance(t, ndarray), value.values())):
return DataFactory("core.array") # ArrayData
return DataFactory("core.array") # ArrayData
else:
return DataFactory("core.dict") # Dict
return DataFactory("core.dict") # Dict
case _:
raise NotImplementedError(f"Received an input of value: {value} with type: {type(value)}")


def fetch(label: str, value: Any) -> list[Node]:
"""Return a list of nodes matching the label and value provided.

Expand All @@ -50,26 +51,31 @@ def fetch(label: str, value: Any) -> list[Node]:
list[Node]: the list of nodes matching the give criteria
"""
datatype = aiida_type(value)
nodes = QueryBuilder(
).append(cls=datatype, tag="n"
).add_filter("n", {"label": label}
).add_filter("n", {"attributes": {"==": datatype(value).base.attributes.all}}
).all(flat=True)
nodes = (
QueryBuilder()
.append(cls=datatype, tag="n")
.add_filter("n", {"label": label})
.add_filter("n", {"attributes": {"==": datatype(value).base.attributes.all}})
.all(flat=True)
)

if datatype != DataFactory("core.array"):
return nodes # type: ignore
return nodes # type: ignore
else:
array_nodes = []
for array_node in nodes:
array_value = {
k: v for k, v in [
(name, array_node.get_array(name)) for name in array_node.get_arraynames() # type: ignore
k: v
for k, v in [
(name, array_node.get_array(name))
for name in array_node.get_arraynames() # type: ignore
]
}
if arraydata_equal(value, array_value):
array_nodes.append(array_node)
return array_nodes


def generate(label: str, value: Any) -> Node:
"""Return a single node with the label and value provided.

Expand All @@ -93,6 +99,7 @@ def generate(label: str, value: Any) -> Node:
else:
raise RuntimeError


def convert(ins: dict[str, Any], path: list[str] = []):
"""Takes a dictionary of inputs and converts the values to their respective Nodes.

Expand All @@ -108,7 +115,8 @@ def convert(ins: dict[str, Any], path: list[str] = []):
else:
ins[k] = generate(".".join([*path, k]), v)

def compile_query(ins: dict[str,Any], qb: QueryBuilder) -> None:

def compile_query(ins: dict[str, Any], qb: QueryBuilder) -> None:
"""Interate over the converted input dictionary and append to the QueryBuilder for each node.

Args:
Expand All @@ -121,18 +129,14 @@ def compile_query(ins: dict[str,Any], qb: QueryBuilder) -> None:
if k in ["microstructure", "error_parameters"] and isinstance(v, dict):
compile_query(v, qb)
else:
qb.append(
cls=type(v),
with_outgoing="calc",
filters={"pk": v.pk}
)
qb.append(cls=type(v), with_outgoing="calc", filters={"pk": v.pk})


def execute_fans(
mode: Literal["Submit", "Run"],
inputs: dict[str, Any],
strategy: Literal["Fragmented", "Stashed"] = "Fragmented",
):
mode: Literal["Submit", "Run"],
inputs: dict[str, Any],
strategy: Literal["Fragmented", "Stashed"] = "Fragmented",
):
"""This utility function simplifies the process of executing aiida-fans jobs.

The only nodes you must provide are the `code` and `microstructure` inputs.
Expand Down Expand Up @@ -191,17 +195,18 @@ def execute_fans(
compile_query(inputs, qb)
results = qb.all(flat=True)
if (count := len(results)) != 0:
print(f"It seems this calculation has already been performed {count} time{"s" if count > 1 else ""}. {results}")
print(f"It seems this calculation has already been performed {count} time{'s' if count > 1 else ''}. {results}")
confirmation = input("Are you sure you want to rerun it? [y/N] ").strip().lower() in ["y", "yes"]
else:
confirmation = True

if confirmation:
match mode:
case "Run":
run(calcjob, inputs) # type: ignore
run(calcjob, inputs) # type: ignore
case "Submit":
submit(calcjob, inputs) # type: ignore
submit(calcjob, inputs) # type: ignore


def submit_fans(
inputs: dict[str, Any],
Expand All @@ -210,6 +215,7 @@ def submit_fans(
"""See `execute_fans` for implementation and usage details."""
execute_fans("Submit", inputs, strategy)


def run_fans(
inputs: dict[str, Any],
strategy: Literal["Fragmented", "Stashed"] = "Fragmented",
Expand Down
Loading