Skip to content

Commit

Permalink
add warning for group member overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Apr 2, 2024
1 parent 4e2ded4 commit d777056
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 21 deletions.
55 changes: 34 additions & 21 deletions cat_scaling/relation/descriptors.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# TODO: add warning for group member overlap

"""Helper class to record descriptors."""

from __future__ import annotations

import warnings
from typing import Optional


class Descriptors:
def __init__(self, groups: dict, method: Optional[str] = None) -> None:
self._groups = groups
self._method = method
self.groups = groups
self.method = method

def __len__(self) -> int:
"""Number of descriptors."""
Expand All @@ -23,32 +22,47 @@ def groups(self) -> dict[str, Optional[list[str]]]:
For example for CO2 to CH4 reduction reaction:
With "traditional" method:
groups = {
"*CO": ["*COOH", "*CHO", "*CH2O"], # C-centered group
"*OH": ["*OCH3", "*O"] # O-Centered group
}
groups = {
"*CO": ["*COOH", "*CHO", "*CH2O"], # C-centered group
"*OH": ["*OCH3", "*O"] # O-Centered group
}
With "adaptive" method:
groups = {
"*CO": None,
"*OH": None
}
groups = {
"*CO": None,
"*OH": None
}
"""

return self._groups

@groups.setter
def groups(self, groups: dict[str, Optional[list[str]]]):
"""Property:groups setter.
Warnings:
A warning would be raised if group members overlap.
"""

if not isinstance(groups, dict):
raise TypeError("Expect groups as dict.")

for key, value in groups.items():
if not isinstance(key, str):
all_members = [] # for check group member overlap

for descriptor, members in groups.items():
if not isinstance(descriptor, str):
raise TypeError("Keys in groups dictionary must be strings.")
if value is not None and not isinstance(value, list):
raise TypeError(
"Group members must be lists of strings or None."
)
if members is not None:
if not isinstance(members, list):
raise TypeError(
"Group members must be lists of strings or None."
)

all_members.extend(members)

# Check for group member overlap
if len(all_members) != len(set(all_members)):
warnings.warn("Descriptor group members overlap.")

self._groups = groups

Expand All @@ -67,8 +81,7 @@ def method(self) -> Optional[str]:

@method.setter
def method(self, method: Optional[str]):
if method is not None:
if method.lower() not in {"traditional", "adaptive"}:
raise ValueError("Invalid method.")
if method is not None and method.lower() not in {"traditional", "adaptive"}:
raise ValueError("Invalid method.")

self._method = method.lower() if method is not None else None
11 changes: 11 additions & 0 deletions tests/relation/test_descriptors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# TODO: add unit test for property: method

import pytest

from cat_scaling.relation import Descriptors


Expand All @@ -26,3 +28,12 @@ def test_init(self):

assert des_adap.descriptors == ["*CO", "*OH"]
assert len(des_adap) == 2

def test_member_overlap(self):
groups = {
"*CO": ["*COOH", "*CHO", "*CH2O"],
"*OH": ["*OCH3", "*O", "*CH2O"],
}

with pytest.warns(UserWarning, match="Descriptor group members overlap."):
Descriptors(groups, method="traditional")

0 comments on commit d777056

Please sign in to comment.