diff --git a/sharrow/aster.py b/sharrow/aster.py index c85f0f9..94ae753 100755 --- a/sharrow/aster.py +++ b/sharrow/aster.py @@ -2,6 +2,9 @@ import io import logging import tokenize +import warnings + +import numpy as np try: from ast import unparse @@ -882,7 +885,7 @@ def visit_Call(self, node): left=ante, ops=[ast.Eq()], comparators=[self.visit(elt)] ) ) - result = ast.BoolOp(op=ast.Or(), values=ors) + result = self.visit(ast.BoolOp(op=ast.Or(), values=ors)) # change `x.between(a,b)` to `(a <= x) & (x <= b)` if isinstance(node.func, ast.Attribute) and node.func.attr == "between": ante = self.visit(node.func.value) @@ -910,6 +913,19 @@ def visit_Call(self, node): args=apply_args, keywords=[], ) + # if the value XXX is a categorical, rather than just `np.isnan` we also + # want to check if the category number is less than zero + if isinstance(apply_args[0], ast.Subscript): + if isinstance(apply_args[0].value, ast.Name): + if apply_args[0].value.id.startswith("__encoding_dict__"): + cat_is_lt_zero = ast.Compare( + left=apply_args[0].slice, + ops=[ast.Lt()], + comparators=[ast.Num(0)], + ) + result = ast.BoolOp( + op=ast.Or(), values=[cat_is_lt_zero, result] + ) # implement x.get("y",z) where x is the spacename if ( @@ -958,6 +974,86 @@ def visit_Call(self, node): self.log_event("visit_Call", node, result) return result + def visit_Compare(self, node): + result = None + + # devolve XXX ==/!= YYY when XXX or YYY is a categorical and the other is a constant category value + if len(node.ops) == 1 and isinstance(node.ops[0], (ast.Eq, ast.NotEq)): + left = self.visit(node.left) + right = self.visit(node.comparators[0]) + left_is_categorical = ( + isinstance(left, ast.Subscript) + and isinstance(left.value, ast.Name) + and left.value.id.startswith("__encoding_dict__") + ) + if left_is_categorical and isinstance(right, ast.Constant): + # left is categorical, right is a constant + left_spacename = left.value.id.split("__")[2] + left_varname = left.value.id.split("__")[3] + if ( + left_spacename == self.spacename + and left_varname in self.digital_encodings + ): + left_dictionary = self.digital_encodings[left_varname].get( + "dictionary", np.atleast_1d([]) + ) + try: + right_decoded = np.where(left_dictionary == right.value)[0][0] + except IndexError: + right_decoded = None + warnings.warn( + f"right hand value {right.value!r} not found in " + f"categories for {left_varname} in {self.spacename}", + stacklevel=2, + ) + if right_decoded is not None: + result = ast.Compare( + left=left.slice, + ops=[self.visit(i) for i in node.ops], + comparators=[ast_Constant(right_decoded)], + ) + right_is_categorical = ( + isinstance(right, ast.Subscript) + and isinstance(right.value, ast.Name) + and right.value.id.startswith("__encoding_dict__") + ) + if right_is_categorical and isinstance(left, ast.Constant): + # right is categorical, left is a constant + right_spacename = right.value.id.split("__")[2] + right_varname = right.value.id.split("__")[3] + if ( + right_spacename == self.spacename + and right_varname in self.digital_encodings + ): + right_dictionary = self.digital_encodings[right_varname].get( + "dictionary", np.atleast_1d([]) + ) + try: + left_decoded = np.where(right_dictionary == left.value)[0][0] + except IndexError: + left_decoded = None + warnings.warn( + f"left hand value {left.value!r} not found in " + f"categories for {right_varname} in {self.spacename}", + stacklevel=2, + ) + if left_decoded is not None: + result = ast.Compare( + left=ast_Constant(left_decoded), + ops=[self.visit(i) for i in node.ops], + comparators=[right.slice], + ) + + # if no other changes + if result is None: + result = ast.Compare( + left=self.visit(node.left), + ops=[self.visit(i) for i in node.ops], + comparators=[self.visit(i) for i in node.comparators], + ) + self.log_event("visit_Compare", node, result) + return result + def expression_for_numba( expr, diff --git a/sharrow/tests/test_categorical.py b/sharrow/tests/test_categorical.py index b500112..247d09e 100644 --- a/sharrow/tests/test_categorical.py +++ b/sharrow/tests/test_categorical.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import pytest import xarray as xr import sharrow @@ -115,3 +116,63 @@ class TourMode(IntEnum): assert df["TourMode2"].dtype == "category" assert all(df["TourMode2"].cat.categories == ["_0", "Car", "Bus", "Walk"]) assert all(df["TourMode2"].cat.codes == [1, 2, 1, 1, 3]) + + +def test_missing_categorical(): + df = pd.DataFrame( + { + "TourMode": ["Car", "Bus", "Car", "Car", "Walk", np.nan], + "person_id": [441, 445, 552, 556, 934, 998], + }, + index=pd.Index([4411, 4451, 5521, 5561, 9341, 9981], name="tour_id"), + ) + df["TourMode2"] = df["TourMode"].astype(pd.CategoricalDtype(["Car", "Bus", "Walk"])) + assert df["TourMode2"].dtype == "category" + assert all(df["TourMode2"].cat.categories == ["Car", "Bus", "Walk"]) + assert all(df["TourMode2"].cat.codes == [0, 1, 0, 0, 2, -1]) + + tree = sharrow.DataTree(df=df, root_node_name=False) + + expr = "df.TourMode2 == 'Bus'" + f = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 1, 0, 0, 0, 0])) + + expr = "df.TourMode2.isna()" + f2 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f2.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 0, 1])) + + expr = "df.TourMode2 == 'Walk'" + f3 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f3.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 1, 0])) + + expr = "'Walk' == df.TourMode2" + f4 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f4.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 1, 0])) + + expr = "df.TourMode2 == 'BAD'" + with pytest.warns(UserWarning): + f5 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f5.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 0, 0])) + + expr = "'BAD' == df.TourMode2" + with pytest.warns(UserWarning): + f6 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f6.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 0, 0])) + + expr = "df.TourMode2 != 'Bus'" + f7 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f7.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([1, 0, 1, 1, 1, 1]))