Skip to content

Commit

Permalink
format improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Apr 2, 2024
1 parent d777056 commit c75e6c3
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 32 deletions.
13 changes: 7 additions & 6 deletions cat_scaling/data/eads.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def data(
try:
data = data.astype(float)
except ValueError as e:
raise ValueError(f"Please double-check input data: {e}.")
raise ValueError(
"Please double-check input data for Eads."
) from e

else:
raise TypeError("Expect data as pd.DataFrame type.")
Expand Down Expand Up @@ -239,14 +241,13 @@ def remove_sample(

def sort_data(
self,
targets: list[str] = ["column", "row"],
targets: set[str] | None = None,
) -> None:
"""Sort columns/rows of data."""
if targets is None:
targets = {"column", "row"}

if not set(targets) <= {
"column",
"row",
}:
elif targets > {"column", "row"}:
raise ValueError(
"Invalid target values. Should be 'column', 'row', or both."
)
Expand Down
17 changes: 5 additions & 12 deletions cat_scaling/data/reaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,14 @@ def __str__(self) -> str:
# Add reactants
reactants = []
for spec, num in self.reactants.items():
if spec.adsorbed:
name = f"*{spec.name}"
else:
name = spec.name
name = f"*{spec.name}" if spec.adsorbed else spec.name

reactants.append(f"{num}{name}")

# Add products
products = []
for spec, num in self.products.items():
if spec.adsorbed:
name = f"*{spec.name}"
else:
name = spec.name
name = f"*{spec.name}" if spec.adsorbed else spec.name

products.append(f"{num}{name}")

Expand Down Expand Up @@ -150,10 +144,9 @@ def _sepa_stoi_number(name: str) -> tuple[float, str]:
name = name.strip()

# Use re to separate leading digits and name
match = re.match(r"^(\d+(\.\d+)?)(.*)$", name)
if match:
stoi_number_str = match.group(1)
species_name = match.group(3)
if match := re.match(r"^(\d+(\.\d+)?)(.*)$", name):
stoi_number_str = match[1]
species_name = match[3]

else:
stoi_number_str = ""
Expand Down
2 changes: 1 addition & 1 deletion cat_scaling/data/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def from_str(cls, string: str) -> Self:
string = string.strip()

# Check if adsorbed
adsorbed: bool = True if string.startswith("*") else False
adsorbed: bool = bool(string.startswith("*"))

# Get energy and correction
e_start = string.find("(")
Expand Down
2 changes: 1 addition & 1 deletion cat_scaling/plotters/volcanos.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def VolcanoPlotter2D(
y = np.linspace(*y_range)

# Generate limiting potential mesh
# TODO: indexes should be used for rate-determining step (RDS)
# TODO: indexes should be used for rate-determining step (RDS) plot
limit_potentials, _rds = relation.eval_limit_potential_2D(x, y)

# Generate limiting potential volcano plot
Expand Down
11 changes: 4 additions & 7 deletions cat_scaling/relation/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ def _build_composite_descriptor(

# Fetch child descriptors
child_descriptors = np.array(
[
self.data.get_adsorbate(species)
for species in spec_ratios.keys()
]
[self.data.get_adsorbate(species) for species in spec_ratios]
)

# Construct composite descriptor (from child descriptors)
Expand Down Expand Up @@ -272,7 +269,7 @@ def build_adaptive(
warnings.warn("Small step length may slow down searching.")

# Convert step_length to percentage
step_length = step_length / 100
step_length /= 100

# Get descriptors as a list of names
_descriptors = descriptors.descriptors
Expand Down Expand Up @@ -304,8 +301,8 @@ def build_adaptive(

scores[ratio] = metrics

# Rerun linear regression with the optimal ratio
opt_ratio = max(scores, key=lambda k: scores[k])
# Find and rerun linear regression with the optimal ratio
opt_ratio = max(scores, key=scores.get) # type: ignore[arg-type]

opt_ratios = {
_descriptors[0]: opt_ratio,
Expand Down
16 changes: 15 additions & 1 deletion cat_scaling/relation/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@


class Descriptors:
"""Helper class to record descriptors.
Attributes:
groups (dict): A dictionary representing groups of adsorbates
and their respective descriptors.
Keys are descriptors and values are lists of group members.
If a group has no members, its value is None.
method (str, optional): The method used for building Relation.
Should be either "traditional" or "adaptive".
"""

def __init__(self, groups: dict, method: Optional[str] = None) -> None:
self.groups = groups
self.method = method
Expand Down Expand Up @@ -81,7 +92,10 @@ def method(self) -> Optional[str]:

@method.setter
def method(self, method: Optional[str]):
if method is not None and method.lower() not in {"traditional", "adaptive"}:
if method is not None and method.lower() not in {
"traditional",
"adaptive",
}:
raise ValueError("Invalid method.")

self._method = method.lower() if method is not None else None
3 changes: 1 addition & 2 deletions cat_scaling/relation/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def coefficients(self, coefficients: list[np.ndarray]):
raise TypeError("All coefficients should be numpy arrays.")

# Check if all arrays have the same length
if len(set(arr.shape[0] for arr in coefficients)) > 1:
if len({arr.shape[0] for arr in coefficients}) > 1:
raise ValueError(
"All coefficient arrays should have the same length."
)
Expand Down Expand Up @@ -331,7 +331,6 @@ def _eval_deltaE_2D(
# Extract coefficients
coef_x, coef_y, intercept = coef

# TODO: more efficient method/algorithm?
return xx * coef_x + yy * coef_y + intercept

def _eval_limit_potential_2D(
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_eads.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_remove_sample(self, setup_class):
assert "Cu@g-C3N4" not in self.eads.samples

def test_sort_df(self, setup_class):
self.eads.sort_data(targets=["column", "row"])
self.eads.sort_data(targets={"column", "row"})

assert self.eads.adsorbates == [
"*CO",
Expand Down
4 changes: 3 additions & 1 deletion tests/relation/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,7 @@ def test_member_overlap(self):
"*OH": ["*OCH3", "*O", "*CH2O"],
}

with pytest.warns(UserWarning, match="Descriptor group members overlap."):
with pytest.warns(
UserWarning, match="Descriptor group members overlap."
):
Descriptors(groups, method="traditional")

0 comments on commit c75e6c3

Please sign in to comment.