-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
This commit provides support for decorators usage with vaex, as well as a result builder to stitch everything back. Squashed commits: * Adds plugin, example and tests for Vaex (#457) * Small fix in description of VaexDataFrameResult class * Adds minor fixes for Vaex Adding to the registry so we try to load the extension if vaex is in the environment. Then fixing a few minor typos otherwise to get things to work. --------- Co-authored-by: Konstantin Tyapochkin <[email protected]>
- Loading branch information
1 parent
8a00d17
commit e2ea1d8
Showing
16 changed files
with
799 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import vaex | ||
|
||
from hamilton.function_modifiers import extract_columns | ||
|
||
|
||
@extract_columns("signups", "spend")
def base_df(base_df_location: str) -> vaex.dataframe.DataFrame:
    """Build the base dataframe that the other columns are derived from.

    The values are hard-coded; ``base_df_location`` is not read by the body
    and only illustrates that the frame could be loaded from a file instead.

    :param base_df_location: placeholder for a data location (unused here).
    :return: a Vaex dataframe holding the ``signups`` and ``spend`` columns.
    """
    raw = pd.DataFrame(
        {
            "signups": [1, 10, 50, 100, 200, 400],
            "spend": [10, 10, 20, 40, 40, 50],
        }
    )
    return vaex.from_pandas(raw)
|
||
|
||
def spend_per_signup(
    spend: vaex.expression.Expression, signups: vaex.expression.Expression
) -> vaex.expression.Expression:
    """Return spend divided by signups, i.e. the cost of each signup.

    :param spend: the spend column.
    :param signups: the signups column.
    :return: the element-wise ratio ``spend / signups``.
    """
    cost_per_signup = spend / signups
    return cost_per_signup
|
||
|
||
def spend_mean(spend: vaex.expression.Expression) -> float:
    """Example of a function producing a scalar: the mean over the whole spend column.

    :param spend: the spend column.
    :return: the result of ``spend.mean()``.
    """
    mean_value = spend.mean()
    return mean_value
|
||
|
||
def spend_zero_mean(spend: vaex.expression.Expression, spend_mean: float) -> np.ndarray:
    """Example of a function that consumes a scalar and returns an ``np.ndarray``.

    :param spend: the spend column.
    :param spend_mean: the scalar subtracted from every element of ``spend``.
    :return: ``spend - spend_mean`` converted with ``to_numpy()``.
    """
    centered = spend - spend_mean
    return centered.to_numpy()
|
||
|
||
def spend_std_dev(spend: vaex.expression.Expression) -> float:
    """Compute the standard deviation of the spend column.

    :param spend: the spend column.
    :return: the result of ``spend.std()``.
    """
    std_value = spend.std()
    return std_value
|
||
|
||
def spend_zero_mean_unit_variance(spend_zero_mean: np.ndarray, spend_std_dev: float) -> np.ndarray:
    """Show one way to give spend zero mean and unit variance.

    :param spend_zero_mean: spend values with the mean already subtracted.
    :param spend_std_dev: the divisor applied to every element.
    :return: ``spend_zero_mean`` divided element-wise by ``spend_std_dev``.
    """
    standardized = spend_zero_mean / spend_std_dev
    return standardized
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import logging | ||
import sys | ||
|
||
from hamilton import base, driver | ||
from hamilton.plugins import h_vaex | ||
|
||
# Route log records to stdout.
logging.basicConfig(stream=sys.stdout)

# The adapter decides how requested outputs are assembled: here, into a Vaex dataframe.
adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
config = {"base_df_location": "dummy_value"}

# Module that holds the dataflow function definitions.
import my_functions

dr = driver.Driver(config, my_functions, adapter=adapter)

# Names of the nodes to compute; each becomes a column of the result.
output_columns = [
    "spend",
    "signups",
    "spend_per_signup",
    "spend_std_dev",
    "spend_mean",
    "spend_zero_mean_unit_variance",
]

# Execute the dataflow and show the resulting dataframe.
df = dr.execute(output_columns)
print(df)
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sf-hamilton[vaex] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
"xgboost", | ||
"lightgbm", | ||
"sklearn_plot", | ||
"vaex", | ||
] | ||
for plugin_module in plugins_modules: | ||
try: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from typing import Any, Dict, List, Type, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from hamilton import base | ||
|
||
try: | ||
import vaex | ||
except ImportError: | ||
raise NotImplementedError("Vaex is not installed.") | ||
|
||
|
||
class VaexDataFrameResult(base.ResultMixin):
    """A ResultBuilder that produces a Vaex dataframe.

    Use this when you want to create a Vaex dataframe from the outputs.

    Caveat: you need to ensure that the length of the outputs is the same
    (except scalars), otherwise you will get an error; mixed outputs aren't
    that well handled.

    To use:

    .. code-block:: python

        from hamilton import base, driver
        from hamilton.plugins import h_vaex
        vaex_builder = h_vaex.VaexDataFrameResult()
        adapter = base.SimplePythonGraphAdapter(vaex_builder)
        dr = driver.Driver(config, *modules, adapter=adapter)
        df = dr.execute([...], inputs=...)  # returns vaex dataframe

    Note: this is just a first attempt at something for Vaex.
    Think it should handle more? Come chat/open a PR!
    """

    @staticmethod
    def _sort_array(
        name: str,
        nparray: np.ndarray,
        arrays: Dict[str, np.ndarray],
        scalars: Dict[str, Any],
    ) -> None:
        """File a numpy value under ``scalars`` or ``arrays``, keyed by ``name``.

        Zero-dimensional arrays and arrays of shape ``(1,)`` are treated as
        scalars; everything else is kept as an array.
        """
        if nparray.ndim == 0:  # value is scalar
            scalars[name] = nparray.item()
        elif nparray.shape == (1,):  # value is scalar
            scalars[name] = nparray[0]
        else:  # value is array
            arrays[name] = nparray

    def build_result(
        self,
        **outputs: Union[vaex.expression.Expression, vaex.dataframe.DataFrame, Any],
    ) -> vaex.dataframe.DataFrame:
        """This is the method that Hamilton will call to build the final result.

        It will pass in the results of the requested outputs that
        you passed in to the execute() method.

        :param outputs: The results of the requested outputs, keyed by name.
            Supported values are Vaex dataframes, Vaex expressions, numpy
            arrays and scalars.
        :return: a Vaex DataFrame.
        :raises NotImplementedError: if a value has an unsupported type, if the
            first array is not one-dimensional, or if the arrays do not all
            share the same shape.
        """
        # We split all outputs into DataFrames, arrays and scalars.
        dfs: List[vaex.dataframe.DataFrame] = []  # Vaex DataFrames from outputs
        arrays: Dict[str, np.ndarray] = {}  # arrays from outputs
        scalars: Dict[str, Any] = {}  # scalars from outputs

        for name, value in outputs.items():
            if isinstance(value, vaex.dataframe.DataFrame):
                dfs.append(value)
            elif isinstance(value, vaex.expression.Expression):
                self._sort_array(name, value.to_numpy(), arrays, scalars)
            elif isinstance(value, np.ndarray):
                self._sort_array(name, value, arrays, scalars)
            elif pd.api.types.is_scalar(value):  # value is scalar
                scalars[name] = value
            else:
                value_type = str(type(value))
                message = f"VaexDataFrameResult doesn't support {value_type}"
                raise NotImplementedError(message)

        df = None

        if arrays:
            # Check if all arrays have correct and identical shapes.
            first_expression_shape = next(iter(arrays.values())).shape
            if len(first_expression_shape) > 1:
                raise NotImplementedError(
                    "VaexDataFrameResult supports only one-dimensional Expression results"
                )
            for a in arrays.values():
                if a.shape != first_expression_shape:
                    raise NotImplementedError(
                        "VaexDataFrameResult supports Expression results with same dimension only"
                    )

            # All scalars become arrays with the same shape as other arrays.
            for name, scalar in scalars.items():
                arrays[name] = np.full(first_expression_shape, scalar)

            df = vaex.from_arrays(**arrays)

        elif scalars:
            # There are no arrays in outputs,
            # so we construct Vaex DataFrame with one row consisting of scalars.
            df = vaex.from_arrays(**{name: np.array([value]) for name, value in scalars.items()})

        # Compare against None explicitly: relying on the dataframe's truthiness
        # would drop a frame that was built but happens to contain zero rows.
        if df is not None:
            dfs.append(df)

        return vaex.concat(dfs)

    def output_type(self) -> Type:
        """Return the type that ``build_result`` produces: a Vaex dataframe."""
        return vaex.dataframe.DataFrame
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
from hamilton import registry | ||
|
||
try: | ||
import vaex | ||
except ImportError: | ||
raise NotImplementedError("Vaex is not installed.") | ||
|
||
DATAFRAME_TYPE = vaex.dataframe.DataFrame | ||
COLUMN_TYPE = vaex.expression.Expression | ||
|
||
|
||
@registry.get_column.register(vaex.dataframe.DataFrame)
def get_column_vaex(df: vaex.dataframe.DataFrame, column_name: str) -> vaex.expression.Expression:
    """Look up ``column_name`` on a Vaex dataframe.

    Registered as the ``get_column`` implementation for Vaex dataframes.

    :param df: the dataframe to read from.
    :param column_name: name of the column to fetch.
    :return: the result of ``df[column_name]``.
    """
    column = df[column_name]
    return column
|
||
|
||
@registry.fill_with_scalar.register(vaex.dataframe.DataFrame)
def fill_with_scalar_vaex(
    df: vaex.dataframe.DataFrame, column_name: str, scalar_value: Any
) -> vaex.dataframe.DataFrame:
    """Assign a constant-valued column to a Vaex dataframe.

    Registered as the ``fill_with_scalar`` implementation for Vaex dataframes.

    :param df: the dataframe that receives the column.
    :param column_name: name under which the column is assigned.
    :param scalar_value: value repeated once per row.
    :return: the same ``df`` object that was passed in.
    """
    n_rows = df.shape[0]
    constant_column = np.full((n_rows,), scalar_value)
    df[column_name] = constant_column
    return df
|
||
|
||
def register_types():
    """Register this extension's dataframe and column types under the name "vaex"."""
    registry.register_types("vaex", DATAFRAME_TYPE, COLUMN_TYPE)


# Register as soon as the module is imported.
register_types()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from hamilton import telemetry | ||
|
||
# Turn Hamilton telemetry off for every test in this package.
telemetry.disable_telemetry()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import vaex | ||
|
||
from hamilton.function_modifiers import extract_columns | ||
|
||
|
||
@extract_columns("a", "b")
def generate_df() -> vaex.dataframe.DataFrame:
    """Build a three-row Vaex dataframe with columns ``a`` and ``b`` for the tests."""
    source = pd.DataFrame({"a": [1, 2, 3], "b": [2, 3, 4]})
    return vaex.from_pandas(source)
|
||
|
||
def a_plus_b_expression(
    a: vaex.expression.Expression, b: vaex.expression.Expression
) -> vaex.expression.Expression:
    """Return ``a + b`` without converting it to a numpy array."""
    total = a + b
    return total
|
||
|
||
def a_plus_b_nparray(a: vaex.expression.Expression, b: vaex.expression.Expression) -> np.ndarray:
    """Return ``a + b`` converted with ``to_numpy()``."""
    total = a + b
    return total.to_numpy()
|
||
|
||
def a_mean(a: vaex.expression.Expression) -> float:
    """Return the mean of column ``a`` (a scalar output for the tests)."""
    mean_of_a = a.mean()
    return mean_of_a
|
||
|
||
def b_mean(b: vaex.expression.Expression) -> float:
    """Return the mean of column ``b`` (a scalar output for the tests)."""
    mean_of_b = b.mean()
    return mean_of_b
|
||
|
||
def ab_as_df(
    a: vaex.expression.Expression, b: vaex.expression.Expression
) -> vaex.dataframe.DataFrame:
    """Repackage ``a`` and ``b`` as a Vaex dataframe with columns ``a_in_df`` and ``b_in_df``."""
    columns = {"a_in_df": a.to_numpy(), "b_in_df": b.to_numpy()}
    return vaex.from_pandas(pd.DataFrame(columns))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import numpy as np | ||
import vaex | ||
|
||
from hamilton import base, driver | ||
from hamilton.plugins import h_vaex, vaex_extensions # noqa F401 | ||
|
||
from .resources import functions | ||
|
||
|
||
def _execute(outputs):
    """Run the resource dataflow with a Vaex result builder and return the result.

    Shared by every test below so the driver/adapter setup lives in one place.
    Also asserts that the result is a Vaex dataframe.

    :param outputs: names of the nodes to request from the driver.
    :return: the dataframe produced by ``VaexDataFrameResult``.
    """
    adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
    dr = driver.Driver({}, functions, adapter=adapter)
    result_df = dr.execute(outputs)
    assert isinstance(result_df, vaex.dataframe.DataFrame)
    return result_df


def test_vaex_column_from_expression():
    """An Expression output becomes a column of the result."""
    result_df = _execute(["a", "b", "a_plus_b_expression"])
    np.testing.assert_allclose(result_df["a_plus_b_expression"].to_numpy(), [3, 5, 7])


def test_vaex_column_from_nparray():
    """A numpy array output becomes a column of the result."""
    result_df = _execute(["a", "b", "a_plus_b_nparray"])
    np.testing.assert_allclose(result_df["a_plus_b_nparray"].to_numpy(), [3, 5, 7])


def test_vaex_scalar_among_columns():
    """A scalar requested alongside columns is repeated for every row."""
    result_df = _execute(["a", "b", "a_mean"])
    np.testing.assert_allclose(result_df["a_mean"].to_numpy(), [2, 2, 2])


def test_vaex_only_scalars():
    """Scalars requested on their own produce a single-row dataframe."""
    result_df = _execute(["a_mean", "b_mean"])
    np.testing.assert_allclose(result_df["a_mean"].to_numpy(), [2])
    np.testing.assert_allclose(result_df["b_mean"].to_numpy(), [3])


def test_vaex_df_among_columns():
    """A dataframe output contributes its own columns next to the others."""
    result_df = _execute(["a", "b", "ab_as_df"])
    np.testing.assert_allclose(result_df["a_in_df"].to_numpy(), result_df["a"].to_numpy())
    np.testing.assert_allclose(result_df["b_in_df"].to_numpy(), result_df["b"].to_numpy())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters