Skip to content

Commit

Permalink
Categoricals: correctly handle missing values, speed up simple string…
Browse files Browse the repository at this point in the history
… comparisons (#46)

* make isna work

* categorical comparisons

* warn on bad categories

* NotEq
  • Loading branch information
jpn-- authored Mar 21, 2024
1 parent 6dc43a8 commit a54cfb3
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
98 changes: 97 additions & 1 deletion sharrow/aster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import io
import logging
import tokenize
import warnings

import numpy as np

try:
from ast import unparse
Expand Down Expand Up @@ -882,7 +885,7 @@ def visit_Call(self, node):
left=ante, ops=[ast.Eq()], comparators=[self.visit(elt)]
)
)
result = ast.BoolOp(op=ast.Or(), values=ors)
result = self.visit(ast.BoolOp(op=ast.Or(), values=ors))
# change `x.between(a,b)` to `(a <= x) & (x <= b)`
if isinstance(node.func, ast.Attribute) and node.func.attr == "between":
ante = self.visit(node.func.value)
Expand Down Expand Up @@ -910,6 +913,19 @@ def visit_Call(self, node):
args=apply_args,
keywords=[],
)
# if the value XXX is a categorical, rather than just `np.isnan` we also
# want to check if the category number is less than zero
if isinstance(apply_args[0], ast.Subscript):
if isinstance(apply_args[0].value, ast.Name):
if apply_args[0].value.id.startswith("__encoding_dict__"):
cat_is_lt_zero = ast.Compare(
left=apply_args[0].slice,
ops=[ast.Lt()],
comparators=[ast.Num(0)],
)
result = ast.BoolOp(
op=ast.Or(), values=[cat_is_lt_zero, result]
)

# implement x.get("y",z) where x is the spacename
if (
Expand Down Expand Up @@ -958,6 +974,86 @@ def visit_Call(self, node):
self.log_event("visit_Call", node, result)
return result

def visit_Compare(self, node):
result = None

# devolve XXX ==/!= YYY when XXX or YYY is a categorical and the other is a constant category value
if len(node.ops) == 1 and isinstance(node.ops[0], (ast.Eq, ast.NotEq)):
left = self.visit(node.left)
right = self.visit(node.comparators[0])
left_is_categorical = (
isinstance(left, ast.Subscript)
and isinstance(left.value, ast.Name)
and left.value.id.startswith("__encoding_dict__")
)
if left_is_categorical and isinstance(right, ast.Constant):
# left is categorical, right is a constant
left_spacename = left.value.id.split("__")[2]
left_varname = left.value.id.split("__")[3]
if (
left_spacename == self.spacename
and left_varname in self.digital_encodings
):
left_dictionary = self.digital_encodings[left_varname].get(
"dictionary", np.atleast_1d([])
)
try:
right_decoded = np.where(left_dictionary == right.value)[0][0]
except IndexError:
right_decoded = None
warnings.warn(
f"right hand value {right.value!r} not found in "
f"categories for {left_varname} in {self.spacename}",
stacklevel=2,
)
if right_decoded is not None:
result = ast.Compare(
left=left.slice,
ops=[self.visit(i) for i in node.ops],
comparators=[ast_Constant(right_decoded)],
)
right_is_categorical = (
isinstance(right, ast.Subscript)
and isinstance(right.value, ast.Name)
and right.value.id.startswith("__encoding_dict__")
)
if right_is_categorical and isinstance(left, ast.Constant):
# right is categorical, left is a constant
right_spacename = right.value.id.split("__")[2]
right_varname = right.value.id.split("__")[3]
if (
right_spacename == self.spacename
and right_varname in self.digital_encodings
):
right_dictionary = self.digital_encodings[right_varname].get(
"dictionary", np.atleast_1d([])
)
try:
left_decoded = np.where(right_dictionary == left.value)[0][0]
except IndexError:
left_decoded = None
warnings.warn(
f"left hand value {left.value!r} not found in "
f"categories for {right_varname} in {self.spacename}",
stacklevel=2,
)
if left_decoded is not None:
result = ast.Compare(
left=ast_Constant(left_decoded),
ops=[self.visit(i) for i in node.ops],
comparators=[right.slice],
)

# if no other changes
if result is None:
result = ast.Compare(
left=self.visit(node.left),
ops=[self.visit(i) for i in node.ops],
comparators=[self.visit(i) for i in node.comparators],
)
self.log_event("visit_Compare", node, result)
return result


def expression_for_numba(
expr,
Expand Down
61 changes: 61 additions & 0 deletions sharrow/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
import pytest
import xarray as xr

import sharrow
Expand Down Expand Up @@ -115,3 +116,63 @@ class TourMode(IntEnum):
assert df["TourMode2"].dtype == "category"
assert all(df["TourMode2"].cat.categories == ["_0", "Car", "Bus", "Walk"])
assert all(df["TourMode2"].cat.codes == [1, 2, 1, 1, 3])


def test_missing_categorical():
df = pd.DataFrame(
{
"TourMode": ["Car", "Bus", "Car", "Car", "Walk", np.nan],
"person_id": [441, 445, 552, 556, 934, 998],
},
index=pd.Index([4411, 4451, 5521, 5561, 9341, 9981], name="tour_id"),
)
df["TourMode2"] = df["TourMode"].astype(pd.CategoricalDtype(["Car", "Bus", "Walk"]))
assert df["TourMode2"].dtype == "category"
assert all(df["TourMode2"].cat.categories == ["Car", "Bus", "Walk"])
assert all(df["TourMode2"].cat.codes == [0, 1, 0, 0, 2, -1])

tree = sharrow.DataTree(df=df, root_node_name=False)

expr = "df.TourMode2 == 'Bus'"
f = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 1, 0, 0, 0, 0]))

expr = "df.TourMode2.isna()"
f2 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f2.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 1]))

expr = "df.TourMode2 == 'Walk'"
f3 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f3.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 1, 0]))

expr = "'Walk' == df.TourMode2"
f4 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f4.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 1, 0]))

expr = "df.TourMode2 == 'BAD'"
with pytest.warns(UserWarning):
f5 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f5.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))

expr = "'BAD' == df.TourMode2"
with pytest.warns(UserWarning):
f6 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f6.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))

expr = "df.TourMode2 != 'Bus'"
f7 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f7.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 0, 1, 1, 1, 1]))

0 comments on commit a54cfb3

Please sign in to comment.