diff --git a/.ci/setup.sh b/.ci/setup.sh
index ce358e47c..f123d62f0 100755
--- a/.ci/setup.sh
+++ b/.ci/setup.sh
@@ -35,6 +35,15 @@ if [[ ${TASK} == "pyspark" ]]; then
fi
fi
+if [[ ${TASK} == "vaex" ]]; then
+ if [[ ${OPERATING_SYSTEM} == "Linux" ]]; then
+ sudo apt-get install \
+ --no-install-recommends \
+ --yes \
+ libpcre3-dev cargo
+ fi
+fi
+
echo "----- python version -----"
python --version
diff --git a/.ci/test.sh b/.ci/test.sh
index aff9ad210..ab5f13661 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -45,6 +45,12 @@ if [[ ${TASK} == "pyspark" ]]; then
exit 0
fi
+if [[ ${TASK} == "vaex" ]]; then
+ pip install -e '.[vaex]'
+ pytest plugin_tests/h_vaex
+ exit 0
+fi
+
if [[ ${TASK} == "tests" ]]; then
pip install .
pytest \
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 4ec6641df..efbfe69bb 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -95,6 +95,12 @@ workflows:
name: ray-py39
python-version: '3.9'
task: ray
+ - test:
+ requires:
+ - check_for_changes
+ name: vaex-py39
+ python-version: '3.9'
+ task: vaex
- test:
requires:
- check_for_changes
diff --git a/examples/vaex/my_functions.py b/examples/vaex/my_functions.py
new file mode 100644
index 000000000..f45dfcb37
--- /dev/null
+++ b/examples/vaex/my_functions.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pandas as pd
+import vaex
+
+from hamilton.function_modifiers import extract_columns
+
+
+@extract_columns("signups", "spend")
+def base_df(base_df_location: str) -> vaex.dataframe.DataFrame:
+ """Loads base dataframe of data.
+
+ :param base_df_location: just showing that we could load this from a file...
+ :return:
+ """
+ return vaex.from_pandas(
+ pd.DataFrame(
+ {
+ "signups": [1, 10, 50, 100, 200, 400],
+ "spend": [10, 10, 20, 40, 40, 50],
+ }
+ )
+ )
+
+
+def spend_per_signup(
+ spend: vaex.expression.Expression, signups: vaex.expression.Expression
+) -> vaex.expression.Expression:
+ """The cost per signup in relation to spend."""
+ return spend / signups
+
+
+def spend_mean(spend: vaex.expression.Expression) -> float:
+ """Shows function creating a scalar. In this case it computes the mean of the entire column."""
+ return spend.mean()
+
+
+def spend_zero_mean(spend: vaex.expression.Expression, spend_mean: float) -> np.ndarray:
+ """Shows function that takes a scalar and returns np.ndarray."""
+ return (spend - spend_mean).to_numpy()
+
+
+def spend_std_dev(spend: vaex.expression.Expression) -> float:
+ """Function that computes the standard deviation of the spend column."""
+ return spend.std()
+
+
+def spend_zero_mean_unit_variance(spend_zero_mean: np.ndarray, spend_std_dev: float) -> np.ndarray:
+ """Function showing one way to make spend have zero mean and unit variance."""
+ return spend_zero_mean / spend_std_dev
diff --git a/examples/vaex/my_script.py b/examples/vaex/my_script.py
new file mode 100644
index 000000000..eb5464ca8
--- /dev/null
+++ b/examples/vaex/my_script.py
@@ -0,0 +1,30 @@
+import logging
+import sys
+
+from hamilton import base, driver
+from hamilton.plugins import h_vaex
+
+logging.basicConfig(stream=sys.stdout)
+
+# Create a driver instance.
+adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
+config = {
+ "base_df_location": "dummy_value",
+}
+
+# where our functions are defined
+import my_functions
+
+dr = driver.Driver(config, my_functions, adapter=adapter)
+output_columns = [
+ "spend",
+ "signups",
+ "spend_per_signup",
+ "spend_std_dev",
+ "spend_mean",
+ "spend_zero_mean_unit_variance",
+]
+
+# let's create the dataframe!
+df = dr.execute(output_columns)
+print(df)
diff --git a/examples/vaex/notebook.ipynb b/examples/vaex/notebook.ipynb
new file mode 100644
index 000000000..e11a604c8
--- /dev/null
+++ b/examples/vaex/notebook.ipynb
@@ -0,0 +1,459 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/konstantin/.miniconda3/envs/hamilton-vaex/lib/python3.9/site-packages/dask/dataframe/_pyarrow_compat.py:17: FutureWarning: Minimal version of pyarrow will soon be increased to 14.0.1. You are using 11.0.0. Please consider upgrading.\n",
+ " warnings.warn(\n",
+ "/Users/konstantin/.miniconda3/envs/hamilton-vaex/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.3\n",
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
+ ]
+ }
+ ],
+ "source": [
+ "from hamilton import base, driver\n",
+ "from hamilton.plugins import vaex_extensions, h_vaex"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import my_functions\n",
+ "\n",
+ "config = {\n",
+ " \"base_df_location\": \"dummy_value\",\n",
+ "}\n",
+ "adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())\n",
+ "dr = driver.Driver(config, my_functions, adapter=adapter)\n",
+ "output_columns = [\n",
+ " \"spend\",\n",
+ " \"signups\",\n",
+ " \"spend_per_signup\",\n",
+ " \"spend_std_dev\",\n",
+ " \"spend_mean\",\n",
+ " \"spend_zero_mean_unit_variance\",\n",
+ "]\n",
+ "df = dr.execute(output_columns)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " # spend signups spend_per_signup spend_zero_mean_unit_variance spend_std_dev spend_mean\n",
+ " 0 10 1 10 -1.166 15.7233 28.3333\n",
+ " 1 10 10 1 -1.166 15.7233 28.3333\n",
+ " 2 20 50 0.4 -0.529999 15.7233 28.3333\n",
+ " 3 40 100 0.4 0.741999 15.7233 28.3333\n",
+ " 4 40 200 0.2 0.741999 15.7233 28.3333\n",
+ " 5 50 400 0.125 1.378 15.7233 28.3333\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dr.visualize_execution(output_columns)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dr.display_all_functions()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "hamilton-vaex",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/vaex/requirements.txt b/examples/vaex/requirements.txt
new file mode 100644
index 000000000..974815425
--- /dev/null
+++ b/examples/vaex/requirements.txt
@@ -0,0 +1 @@
+sf-hamilton[vaex]
diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py
index bcaad3245..3d01eb931 100644
--- a/hamilton/function_modifiers/base.py
+++ b/hamilton/function_modifiers/base.py
@@ -33,6 +33,7 @@
"xgboost",
"lightgbm",
"sklearn_plot",
+ "vaex",
]
for plugin_module in plugins_modules:
try:
diff --git a/hamilton/plugins/h_vaex.py b/hamilton/plugins/h_vaex.py
new file mode 100644
index 000000000..f80b9ec25
--- /dev/null
+++ b/hamilton/plugins/h_vaex.py
@@ -0,0 +1,113 @@
+from typing import Any, Dict, List, Type, Union
+
+import numpy as np
+import pandas as pd
+
+from hamilton import base
+
+try:
+ import vaex
+except ImportError:
+ raise NotImplementedError("Vaex is not installed.")
+
+
+class VaexDataFrameResult(base.ResultMixin):
+ """A ResultBuilder that produces a Vaex dataframe.
+
+ Use this when you want to create a Vaex dataframe from the outputs.
+ Caveat: you need to ensure that the length
+ of the outputs is the same (except scalars), otherwise you will get an error;
+ mixed outputs aren't that well handled.
+
+ To use:
+
+ .. code-block:: python
+
+ from hamilton import base, driver
+ from hamilton.plugins import h_vaex
+ vaex_builder = h_vaex.VaexDataFrameResult()
+ adapter = base.SimplePythonGraphAdapter(vaex_builder)
+ dr = driver.Driver(config, *modules, adapter=adapter)
+ df = dr.execute([...], inputs=...) # returns vaex dataframe
+
+ Note: this is just a first attempt at something for Vaex.
+ Think it should handle more? Come chat/open a PR!
+ """
+
+ def build_result(
+ self,
+ **outputs: Dict[str, Union[vaex.expression.Expression, vaex.dataframe.DataFrame, Any]],
+ ):
+ """This is the method that Hamilton will call to build the final result.
+ It will pass in the results of the requested outputs that
+ you passed in to the execute() method.
+
+ :param outputs: The results of the requested outputs.
+ :return: a Vaex DataFrame.
+ """
+
+ # We split all outputs into DataFrames, arrays and scalars
+ dfs: List[vaex.dataframe.DataFrame] = [] # Vaex DataFrames from outputs
+ arrays: Dict[str, np.ndarray] = dict() # arrays from outputs
+ scalars: Dict[str, Any] = dict() # scalars from outputs
+
+ for name, value in outputs.items():
+ if isinstance(value, vaex.dataframe.DataFrame):
+ dfs.append(value)
+ elif isinstance(value, vaex.expression.Expression):
+ nparray = value.to_numpy()
+ if nparray.ndim == 0: # value is scalar
+ scalars[name] = nparray.item()
+ elif nparray.shape == (1,): # value is scalar
+ scalars[name] = nparray[0]
+ else: # value is array
+ arrays[name] = nparray
+ elif isinstance(value, np.ndarray):
+ if value.ndim == 0: # value is scalar
+ scalars[name] = value.item()
+ elif value.shape == (1,): # value is scalar
+ scalars[name] = value[0]
+ else: # value is array
+ arrays[name] = value
+ elif pd.api.types.is_scalar(value): # value is scalar
+ scalars[name] = value
+ else:
+ value_type = str(type(value))
+ message = f"VaexDataFrameResult doesn't support {value_type}"
+ raise NotImplementedError(message)
+
+ df = None
+
+ if arrays:
+
+ # Check if all arrays have correct and identical shapes.
+ first_expression_shape = next(arrays.values().__iter__()).shape
+ if len(first_expression_shape) > 1:
+ raise NotImplementedError(
+ "VaexDataFrameResult supports only one-dimensional Expression results"
+ )
+ for name, a in arrays.items():
+ if a.shape != first_expression_shape:
+ raise NotImplementedError(
+ "VaexDataFrameResult supports Expression results with same dimension only"
+ )
+
+ # All scalars become arrays with the same shape as other arrays.
+ for name, scalar in scalars.items():
+ arrays[name] = np.full(first_expression_shape, scalar)
+
+ df = vaex.from_arrays(**arrays)
+
+ elif scalars:
+
+ # There are not arrays in outputs,
+ # so we construct Vaex DataFrame with one row consisting of scalars.
+ df = vaex.from_arrays(**{name: np.array([value]) for name, value in scalars.items()})
+
+ if df:
+ dfs.append(df)
+
+ return vaex.concat(dfs)
+
+ def output_type(self) -> Type:
+ return vaex.dataframe.DataFrame
diff --git a/hamilton/plugins/vaex_extensions.py b/hamilton/plugins/vaex_extensions.py
new file mode 100644
index 000000000..2081ed4ef
--- /dev/null
+++ b/hamilton/plugins/vaex_extensions.py
@@ -0,0 +1,34 @@
+from typing import Any
+
+import numpy as np
+
+from hamilton import registry
+
+try:
+ import vaex
+except ImportError:
+ raise NotImplementedError("Vaex is not installed.")
+
+DATAFRAME_TYPE = vaex.dataframe.DataFrame
+COLUMN_TYPE = vaex.expression.Expression
+
+
+@registry.get_column.register(vaex.dataframe.DataFrame)
+def get_column_vaex(df: vaex.dataframe.DataFrame, column_name: str) -> vaex.expression.Expression:
+ return df[column_name]
+
+
+@registry.fill_with_scalar.register(vaex.dataframe.DataFrame)
+def fill_with_scalar_vaex(
+ df: vaex.dataframe.DataFrame, column_name: str, scalar_value: Any
+) -> vaex.dataframe.DataFrame:
+ df[column_name] = np.full((df.shape[0],), scalar_value)
+ return df
+
+
+def register_types():
+ """Function to register the types for this extension."""
+ registry.register_types("vaex", DATAFRAME_TYPE, COLUMN_TYPE)
+
+
+register_types()
diff --git a/plugin_tests/h_vaex/__init__.py b/plugin_tests/h_vaex/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/plugin_tests/h_vaex/conftest.py b/plugin_tests/h_vaex/conftest.py
new file mode 100644
index 000000000..bc5ef5b5a
--- /dev/null
+++ b/plugin_tests/h_vaex/conftest.py
@@ -0,0 +1,4 @@
+from hamilton import telemetry
+
+# disable telemetry for all tests!
+telemetry.disable_telemetry()
diff --git a/plugin_tests/h_vaex/resources/__init__.py b/plugin_tests/h_vaex/resources/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/plugin_tests/h_vaex/resources/functions.py b/plugin_tests/h_vaex/resources/functions.py
new file mode 100644
index 000000000..e2f14f657
--- /dev/null
+++ b/plugin_tests/h_vaex/resources/functions.py
@@ -0,0 +1,34 @@
+import numpy as np
+import pandas as pd
+import vaex
+
+from hamilton.function_modifiers import extract_columns
+
+
+@extract_columns("a", "b")
+def generate_df() -> vaex.dataframe.DataFrame:
+ return vaex.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [2, 3, 4]}))
+
+
+def a_plus_b_expression(
+ a: vaex.expression.Expression, b: vaex.expression.Expression
+) -> vaex.expression.Expression:
+ return a + b
+
+
+def a_plus_b_nparray(a: vaex.expression.Expression, b: vaex.expression.Expression) -> np.ndarray:
+ return (a + b).to_numpy()
+
+
+def a_mean(a: vaex.expression.Expression) -> float:
+ return a.mean()
+
+
+def b_mean(b: vaex.expression.Expression) -> float:
+ return b.mean()
+
+
+def ab_as_df(
+ a: vaex.expression.Expression, b: vaex.expression.Expression
+) -> vaex.dataframe.DataFrame:
+ return vaex.from_pandas(pd.DataFrame({"a_in_df": a.to_numpy(), "b_in_df": b.to_numpy()}))
diff --git a/plugin_tests/h_vaex/test_h_vaex.py b/plugin_tests/h_vaex/test_h_vaex.py
new file mode 100644
index 000000000..e0fcb6ac1
--- /dev/null
+++ b/plugin_tests/h_vaex/test_h_vaex.py
@@ -0,0 +1,49 @@
+import numpy as np
+import vaex
+
+from hamilton import base, driver
+from hamilton.plugins import h_vaex, vaex_extensions # noqa F401
+
+from .resources import functions
+
+
+def test_vaex_column_from_expression():
+ adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
+ dr = driver.Driver({}, functions, adapter=adapter)
+ result_df = dr.execute(["a", "b", "a_plus_b_expression"])
+ assert isinstance(result_df, vaex.dataframe.DataFrame)
+ np.testing.assert_allclose(result_df["a_plus_b_expression"].to_numpy(), [3, 5, 7])
+
+
+def test_vaex_column_from_nparray():
+ adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
+ dr = driver.Driver({}, functions, adapter=adapter)
+ result_df = dr.execute(["a", "b", "a_plus_b_nparray"])
+ assert isinstance(result_df, vaex.dataframe.DataFrame)
+ np.testing.assert_allclose(result_df["a_plus_b_nparray"].to_numpy(), [3, 5, 7])
+
+
+def test_vaex_scalar_among_columns():
+ adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
+ dr = driver.Driver({}, functions, adapter=adapter)
+ result_df = dr.execute(["a", "b", "a_mean"])
+ assert isinstance(result_df, vaex.dataframe.DataFrame)
+ np.testing.assert_allclose(result_df["a_mean"].to_numpy(), [2, 2, 2])
+
+
+def test_vaex_only_scalars():
+ adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
+ dr = driver.Driver({}, functions, adapter=adapter)
+ result_df = dr.execute(["a_mean", "b_mean"])
+ assert isinstance(result_df, vaex.dataframe.DataFrame)
+ np.testing.assert_allclose(result_df["a_mean"].to_numpy(), [2])
+ np.testing.assert_allclose(result_df["b_mean"].to_numpy(), [3])
+
+
+def test_vaex_df_among_columns():
+ adapter = base.SimplePythonGraphAdapter(result_builder=h_vaex.VaexDataFrameResult())
+ dr = driver.Driver({}, functions, adapter=adapter)
+ result_df = dr.execute(["a", "b", "ab_as_df"])
+ assert isinstance(result_df, vaex.dataframe.DataFrame)
+ np.testing.assert_allclose(result_df["a_in_df"].to_numpy(), result_df["a"].to_numpy())
+ np.testing.assert_allclose(result_df["b_in_df"].to_numpy(), result_df["b"].to_numpy())
diff --git a/setup.py b/setup.py
index d11647964..c12b375bb 100644
--- a/setup.py
+++ b/setup.py
@@ -84,6 +84,10 @@ def load_requirements():
"pandera": ["pandera"],
"tqdm": ["tqdm"],
"datadog": ["ddtrace"],
+ "vaex": [
+ "pydantic<2.0", # because of https://github.com/vaexio/vaex/issues/2384
+ "vaex",
+ ],
"experiments": [
"fastapi",
"fastui",