Skip to content

Commit

Permalink
Added last workflow validation (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
canimus authored Oct 1, 2023
1 parent e245850 commit d319c4c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
34 changes: 12 additions & 22 deletions cuallee/polars_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,28 +290,18 @@ def has_workflow(self, rule: Rule, dataframe: pl.DataFrame) -> Union[bool, int]:

def workflow(dataframe):
group, event, order = rule.column
CUALLEE_EVENT = "cuallee_event"
CUALLEE_EDGE = "cuallee_edge"
CUALLEE_GRAPH = "cuallee_graph"
dataframe[CUALLEE_EVENT] = (
dataframe.loc[:, rule.column]
.sort_values(by=[group, order], ascending=True)
.groupby([group])[event]
.shift(-1)
.replace(np.nan, None)
)
dataframe[CUALLEE_EDGE] = dataframe[[event, CUALLEE_EVENT]].apply(
lambda x: (x[event], x[CUALLEE_EVENT]), axis=1
)
dataframe[CUALLEE_GRAPH] = list(repeat(rule.value, len(dataframe)))

return (
dataframe.apply(lambda x: x[CUALLEE_EDGE] in x[CUALLEE_GRAPH], axis=1)
.astype("int")
.sum()
)

return workflow(dataframe.loc[:, rule.column])
groups = dataframe.partition_by(group)
interactions = []
_d = compose(list, operator.methodcaller("values"), operator.methodcaller("to_dict", as_series=False))
for g in groups:
pairs = list(zip(*_d(g.select(pl.col(event), pl.col(event).shift(-1).alias("target")))))
if result := set(pairs).difference(rule.value):
for t in result:
interactions.append(t)

return len(dataframe) - len(interactions)

return workflow(dataframe.select(*rule.column))


def compute(rules: Dict[str, Rule]):
Expand Down
30 changes: 30 additions & 0 deletions test/unit/polars_dataframe/test_has_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import polars as pl
from cuallee import Check
import pytest


def test_positive(check: Check):
check.has_workflow("group", "event", "order", [("x", "y"), ("y", "z"), ("z", None)])
df = pl.DataFrame(
{"group": list("AAABBB"), "event": list("xyzxyz"), "order": [1, 2, 3, 1, 2, 3]}
)
result = check.validate(df).select(pl.col("status")) == "PASS"
assert all(result.to_series().to_list())


def test_negative(check: Check):
check.has_workflow("group", "event", "order", [("x", "y"), ("y", "z")])
df = pl.DataFrame(
{"group": list("AAABBB"), "event": list("xyzxyz"), "order": [1, 2, 3, 1, 2, 3]}
)
result = check.validate(df).select(pl.col("status")) == "FAIL"
assert all(result.to_series().to_list())


def test_coverage(check: Check):
check.has_workflow("group", "event", "order", [("x", "y"), ("y", "z")], pct=4 / 6)
df = pl.DataFrame(
{"group": list("AAABBB"), "event": list("xyzxyz"), "order": [1, 2, 3, 1, 2, 3]}
)
result = check.validate(df).select(pl.col("status")) == "PASS"
assert all(result.to_series().to_list())

0 comments on commit d319c4c

Please sign in to comment.