diff --git a/examples/data-filter-extension-categorical.ipynb b/examples/data-filter-extension-categorical.ipynb new file mode 100644 index 00000000..1e6d7be1 --- /dev/null +++ b/examples/data-filter-extension-categorical.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5a203c06-68a6-4335-9037-aef706980245", + "metadata": {}, + "outputs": [], + "source": [ + "# /// script\n", + "# requires-python = \">=3.12\"\n", + "# dependencies = [\n", + "# \"geodatasets\",\n", + "# \"geopandas\",\n", + "# \"ipywidgets\",\n", + "# \"lonboard\",\n", + "# \"numpy\",\n", + "# \"palettable\",\n", + "# \"pandas\",\n", + "# \"shapely\",\n", + "# ]\n", + "# ///" + ] + }, + { + "cell_type": "markdown", + "id": "ad4207fc-7b6d-4301-b79e-34ff1d961994", + "metadata": {}, + "source": [ + "# Categorical Filtering with the DataFilterExtension\n", + "\n", + "The `DataFilterExtension` adds GPU-based data filtering functionalities to layers, allowing the layer to show/hide objects based on user-defined properties." + ] + }, + { + "cell_type": "markdown", + "id": "0eba4205-2605-4922-be6a-ee71a4d8e389", + "metadata": {}, + "source": [ + "## Dependencies\n", + "\n", + "Install [`uv`](https://docs.astral.sh/uv) and then launch this notebook with:\n", + "\n", + "```\n", + "uvx juv run examples/data-filter-extension-categorical.ipynb\n", + "```\n", + "\n", + "(The `uvx` command is included when installing `uv`)." + ] + }, + { + "cell_type": "markdown", + "id": "a4a41490-4c62-4ed1-90f9-5126fa3e532f", + "metadata": {}, + "source": [ + "## Categorical Filtering\n", + "In this example the `DataFilterExtension` will be used to filter the display of the home sales dataset from geodatasets based on the number of bedrooms and bathrooms.\n", + "\n", + "To demonstrate, we'll create:\n", + "\n", + "1. A GeoPandas GeoDataFrame of home sales data.\n", + "2. 
A Lonboard ScatterplotLayer from the GeoDataFrame that has a DataFilterExtension set up for categorical filtering.\n", + "3. Some ipywidgets linked to the ScatterplotLayer's `filter_categories` property to allow us to interactively filter the points on the map." + ] + }, + { + "cell_type": "markdown", + "id": "799e41b8-e75a-453c-a9ac-59ed6e254cd7", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7965b2a-4e93-438f-81ea-cf39917fb654", + "metadata": {}, + "outputs": [], + "source": [ + "import geodatasets\n", + "import geopandas\n", + "import traitlets\n", + "from ipywidgets import Button, HBox, SelectMultiple, VBox\n", + "from palettable.colorbrewer.diverging import RdYlGn_11\n", + "\n", + "import lonboard\n", + "from lonboard import Map, ScatterplotLayer\n", + "from lonboard.layer_extension import DataFilterExtension" + ] + }, + { + "cell_type": "markdown", + "id": "6a703ab0-6b15-4e6e-b45b-ed915fa4d670", + "metadata": {}, + "source": [ + "### Reading the Home Sales Data\n", + "\n", + "This example makes use of the geodatasets Python package to access some spatial data easily.\n", + "\n", + "Calling `geodatasets.get_path()` will download the specified data to the machine and return the path to the downloaded file. If the file has already been downloaded it will simply return the path to the file. [See downloading and caching](https://geodatasets.readthedocs.io/en/latest/introduction.html#downloading-and-caching) for further details." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78c0a363-83dd-4dee-8ac6-10ff67e8e817", + "metadata": {}, + "outputs": [], + "source": [ + "home_sales_df = geopandas.read_file(geodatasets.get_path(\"geoda.home_sales\"))[\n", + " [\"price\", \"bedrooms\", \"bathrooms\", \"geometry\"]\n", + "]\n", + "home_sales_df" + ] + }, + { + "cell_type": "markdown", + "id": "c295793b-fc34-4e34-a2a0-6f941d7bf8fd", + "metadata": {}, + "source": [ + "### Create the `ScatterplotLayer` with a `DataFilterExtension` extension for categorical filtering\n", + "\n", + "The `DataFilterExtension` will be created with `filter_size=None` to indicate we do not want to use a range filter, and `category_size=2` to indicate we want to use two different categories from the data to filter the data with explicit values.\n", + "\n", + "The points in the layer will be symbolized based on a continuous colormap of the price. Lower priced homes will be red, and higher priced homes will be green. We'll exclude the upper and lower 5% of values when setting the colormap bounds so that outliers do not influence the colormap." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14a311f4-7344-496d-a9b3-7ec8eaee5cef", + "metadata": {}, + "outputs": [], + "source": [ + "min_bound = home_sales_df[\"price\"].quantile(0.05)\n", + "max_bound = home_sales_df[\"price\"].quantile(0.95)\n", + "price = home_sales_df[\"price\"]\n", + "normalized_price = (price - min_bound) / (max_bound - min_bound)\n", + "\n", + "home_sale_layer = ScatterplotLayer.from_geopandas(\n", + " home_sales_df,\n", + " get_fill_color=lonboard.colormap.apply_continuous_cmap(normalized_price, RdYlGn_11),\n", + " radius_min_pixels=5,\n", + " extensions=[\n", + " DataFilterExtension(filter_size=None, category_size=2),\n", + " ],\n", + " get_filter_category=home_sales_df[[\"bedrooms\", \"bathrooms\"]].values,\n", + " filter_categories=[[], []],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "83854bcd-a17c-4d2a-91b7-c6626dd67ede", + "metadata": {}, + "source": [ + "### Create the ipywidgets to interact with the `DataFilterExtension`\n", + "\n", + "Since we want to only display the points which are tied to specific numbers of bedrooms and bathrooms we'll:\n", + "\n", + "1. Create two ipywidgets `SelectMultiple` widgets which will hold the different numbers of bedrooms and bathrooms in the home sales data.\n", + "2. Create two ipywidgets `Button` widgets which will clear the selected values for the bedrooms and bathrooms.\n", + "3. Observe the changes made to the `SelectMultiple` widgets to update the layer's `filter_categories` property.\n", + "\n", + "This will enable us to select one or more numbers of bedrooms or bathrooms and have the map instantly react to display only the data that matches the selections. If a select widget does not have a selection, all the values from that selector will be used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ea817c2-7a6d-4b43-99f4-ae0f638151e9", + "metadata": {}, + "outputs": [], + "source": [ + "unique_bedrooms_values = list(home_sales_df[\"bedrooms\"].sort_values().unique())\n", + "unique_bathrooms_values = list(home_sales_df[\"bathrooms\"].sort_values().unique())\n", + "\n", + "bedrooms_select = SelectMultiple(description=\"Bedrooms\", options=unique_bedrooms_values)\n", + "bathrooms_select = SelectMultiple(\n", + " description=\"Bathrooms\",\n", + " options=unique_bathrooms_values,\n", + ")\n", + "\n", + "\n", + "def on_select_change(_: traitlets.utils.bunch.Bunch = None) -> None:\n", + " \"\"\"Set the layer's filter_categories property based on widget selections.\"\"\"\n", + " bedrooms = bedrooms_select.value\n", + " bathrooms = bathrooms_select.value\n", + " if len(bedrooms) == 0:\n", + " bedrooms = unique_bedrooms_values\n", + " if len(bathrooms) == 0:\n", + " bathrooms = unique_bathrooms_values\n", + " home_sale_layer.filter_categories = [bedrooms, bathrooms]\n", + "\n", + "\n", + "bedrooms_select.observe(on_select_change, \"value\")\n", + "bathrooms_select.observe(on_select_change, \"value\")\n", + "\n", + "clear_bedrooms_button = Button(description=\"Clear Bedrooms\")\n", + "\n", + "\n", + "def clear_bedrooms(_: Button) -> None:\n", + " bedrooms_select.value = []\n", + "\n", + "\n", + "clear_bedrooms_button.on_click(clear_bedrooms)\n", + "\n", + "clear_bathrooms_button = Button(description=\"Clear Bathrooms\")\n", + "\n", + "\n", + "def clear_bathrooms(_: Button) -> None:\n", + " bathrooms_select.value = []\n", + "\n", + "\n", + "clear_bathrooms_button.on_click(clear_bathrooms)\n", + "\n", + "home_sale_map = Map(\n", + " layers=[home_sale_layer],\n", + " basemap=lonboard.basemap.MaplibreBasemap(),\n", + ")\n", + "on_select_change() # fire the function once to initially set the layer's filter_categories, and display all points\n", + "\n", + "display(home_sale_map)\n", + "display(\n", + " 
HBox(\n", + " [\n", + " VBox([bedrooms_select, clear_bedrooms_button]),\n", + " VBox([bathrooms_select, clear_bathrooms_button]),\n", + " ],\n", + " ),\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lonboard_category_filter", + "language": "python", + "name": "lonboard_category_filter" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lonboard/layer_extension.py b/lonboard/layer_extension.py index 306a3c94..ac5bd02a 100644 --- a/lonboard/layer_extension.py +++ b/lonboard/layer_extension.py @@ -5,6 +5,7 @@ from lonboard._base import BaseExtension from lonboard.traits import ( DashArrayAccessor, + FilterCategoryAccessor, FilterValueAccessor, FloatAccessor, PointAccessor, @@ -353,15 +354,18 @@ class DataFilterExtension(BaseExtension): "filter_transform_size": t.Bool(default_value=True).tag(sync=True), "filter_transform_color": t.Bool(default_value=True).tag(sync=True), "get_filter_value": FilterValueAccessor(default_value=None, allow_none=True), - "get_filter_category": FilterValueAccessor(default_value=None, allow_none=True), + "get_filter_category": FilterCategoryAccessor( + default_value=None, + allow_none=True, + ), } - filter_size = t.Int(None, min=1, max=4, allow_none=True).tag(sync=True) + filter_size = t.Int(1, min=1, max=4, allow_none=True).tag(sync=True) """The size of the filter (number of columns to filter by). The data filter can show/hide data based on 1-4 numeric properties of each object. - - Type: `int`. This is required if using range-based filtering. + - Type: `int`, optional. This is required if using range-based filtering. - Default 1. 
""" @@ -370,8 +374,8 @@ class DataFilterExtension(BaseExtension): The category filter can show/hide data based on 1-4 properties of each object. - - Type: `int`. This is required if using category-based filtering. - - Default 0. + - Type: `int`, optional. This is required if using category-based filtering. + - Default None. """ diff --git a/lonboard/traits/__init__.py b/lonboard/traits/__init__.py index 3265eebf..7cd8c4b4 100644 --- a/lonboard/traits/__init__.py +++ b/lonboard/traits/__init__.py @@ -7,7 +7,7 @@ from ._a5 import A5Accessor from ._base import FixedErrorTraitType, VariableLengthTuple from ._color import ColorAccessor -from ._extensions import DashArrayAccessor, FilterValueAccessor +from ._extensions import DashArrayAccessor, FilterCategoryAccessor, FilterValueAccessor from ._float import FloatAccessor from ._h3 import H3Accessor from ._map import BasemapUrl, MapHeightTrait, ViewStateTrait @@ -23,6 +23,7 @@ "BasemapUrl", "ColorAccessor", "DashArrayAccessor", + "FilterCategoryAccessor", "FilterValueAccessor", "FixedErrorTraitType", "FloatAccessor", diff --git a/lonboard/traits/_extensions.py b/lonboard/traits/_extensions.py index dc3b8cf7..59baaf7c 100644 --- a/lonboard/traits/_extensions.py +++ b/lonboard/traits/_extensions.py @@ -228,6 +228,178 @@ def validate( return value.rechunk(max_chunksize=obj._rows_per_chunk) +class FilterCategoryAccessor(FixedErrorTraitType): + """Validate input for `get_filter_category`. + + A trait to validate input for the `get_filter_category` accessor added by the + [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension], which can + have between 1 and 4 values per row. + + + Various input is allowed: + + - An `int` or `float`. This will be used as the value for all objects. The + `category_size` of the + [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension] instance + must be 1. + - A one-dimensional numpy `ndarray` with a numeric data type. 
Each value in the array will + be used as the value for the object at the same row index. The `category_size` of + the [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension] instance + must be 1. + - A two-dimensional numpy `ndarray` with a numeric data type. Each value in the array will + be used as the value for the object at the same row index. The `category_size` of + the [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension] instance + must match the size of the second dimension of the array. + - A pandas `Series` with a numeric data type. Each value in the array will be used as + the value for the object at the same row index. The `category_size` of the + [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension] instance + must be 1. + - A pyarrow [`FloatArray`][pyarrow.FloatArray], [`DoubleArray`][pyarrow.DoubleArray] + or [`ChunkedArray`][pyarrow.ChunkedArray] containing either a `FloatArray` or + `DoubleArray`. Each value in the array will be used as the value for the object at + the same row index. The `category_size` of the + [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension] instance + must be 1. + + Alternatively, you can pass any corresponding Arrow data structure from a library + that implements the [Arrow PyCapsule + Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html). + - A pyarrow [`FixedSizeListArray`][pyarrow.FixedSizeListArray] or + [`ChunkedArray`][pyarrow.ChunkedArray] containing `FixedSizeListArray`s. The `category_size` of + the [`DataFilterExtension`][lonboard.layer_extension.DataFilterExtension] instance + must match the list size. + + Alternatively, you can pass any corresponding Arrow data structure from a library + that implements the [Arrow PyCapsule + Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html). 
+ """ + + default_value = None + info_text = "a value or numpy ndarray or Arrow array representing an array of data" + + def __init__( + self: TraitType, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.tag(sync=True, **ACCESSOR_SERIALIZATION) + + def _pandas_to_numpy( + self, + obj: BaseArrowLayer, + value: Any, + category_size: int, + ) -> np.ndarray: + # Assert that category_size == 1 for a pandas series. + # Pandas series can technically contain Python list objects inside them, but + # for simplicity we disallow that. + if category_size != 1: + self.error(obj, value, info="category_size==1 with pandas Series") + + # Cast pandas Series to numpy ndarray + return np.asarray(value) + + def _numpy_to_arrow( + self, + obj: BaseArrowLayer, + value: Any, + category_size: int, + ) -> ChunkedArray: + if len(value.shape) == 1: + if category_size != 1: + self.error(obj, value, info="category_size==1 with 1-D numpy array") + array = fixed_size_list_array(value, category_size) + return ChunkedArray(array) + + if len(value.shape) != 2: + self.error(obj, value, info="1-D or 2-D numpy array") + + if value.shape[1] != category_size: + self.error( + obj, + value, + info=( + f"category_size ({category_size}) to match 2nd dimension of numpy array" + ), + ) + array = fixed_size_list_array(value.ravel("C"), category_size) + return ChunkedArray([array]) + + def validate( + self, + obj: BaseArrowLayer, + value: Any, + ) -> str | float | tuple | list | ChunkedArray: + # Find the data filter extension in the attributes of the parent object so we + # can validate against the filter size. 
+ data_filter_extension = [ + ext + for ext in obj.extensions + if ext._extension_type == "data-filter" # type: ignore + ] + assert len(data_filter_extension) == 1 + category_size = data_filter_extension[0].category_size # type: ignore + + if isinstance(value, (int, float, str)): + if category_size != 1: + self.error(obj, value, info="category_size==1 with scalar value") + return value + + if isinstance(value, (tuple, list)): + if category_size != len(value): + self.error( + obj, + value, + info=f"category_size ({category_size}) to match length of tuple/list", + ) + return value + + # pandas Series + if ( + value.__class__.__module__.startswith("pandas") + and value.__class__.__name__ == "Series" + ): + value = self._pandas_to_numpy(obj, value, category_size) + + if isinstance(value, np.ndarray): + value = self._numpy_to_arrow(obj, value, category_size) + elif hasattr(value, "__arrow_c_array__"): + value = ChunkedArray([Array.from_arrow(value)]) + elif hasattr(value, "__arrow_c_stream__"): + value = ChunkedArray.from_arrow(value) + else: + self.error(obj, value) + + assert isinstance(value, ChunkedArray) + + # Allowed inputs are either a FixedSizeListArray or array. + if not DataType.is_fixed_size_list(value.type): + if category_size != 1: + self.error( + obj, + value, + info="category_size==1 with non-FixedSizeList type arrow array", + ) + + return value + + # We have a FixedSizeListArray + if category_size != value.type.list_size: + self.error( + obj, + value, + info=( + f"category_size ({category_size}) to match list size of " + "FixedSizeList arrow array" + ), + ) + + value_type = value.type.value_type + assert value_type is not None + return value.rechunk(max_chunksize=obj._rows_per_chunk) + + class DashArrayAccessor(FixedErrorTraitType): """A trait to validate input for a deck.gl dash accessor. 
diff --git a/mkdocs.yml b/mkdocs.yml index 4cfad63b..f00d19c6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -47,6 +47,7 @@ nav: - examples/kontur_pop.ipynb - examples/migration.ipynb - examples/data-filter-extension.ipynb + - examples/data-filter-extension-categorical.ipynb - examples/column-layer.ipynb - examples/interleaved-labels.ipynb - examples/linked-maps.ipynb diff --git a/src/model/extension.ts b/src/model/extension.ts index ede75be1..9cfe3cca 100644 --- a/src/model/extension.ts +++ b/src/model/extension.ts @@ -147,19 +147,15 @@ export class DataFilterExtension extends BaseExtensionModel { } extensionInstance(): _DataFilterExtension | null { - if (isDefined(this.filterSize)) { - const props = { - ...(isDefined(this.filterSize) ? { filterSize: this.filterSize } : {}), - }; - // console.log("ext props", props); - return new _DataFilterExtension(props); - } else if (isDefined(this.categorySize)) { + if (isDefined(this.filterSize) || isDefined(this.categorySize)) { const props = { + ...(isDefined(this.filterSize) + ? { filterSize: this.filterSize != null ? this.filterSize : 0 } + : {}), ...(isDefined(this.categorySize) - ? { categorySize: this.categorySize } + ? { categorySize: this.categorySize != null ? 
this.categorySize : 0 } : {}), }; - // console.log("ext props", props); return new _DataFilterExtension(props); } else { return null; diff --git a/tests/traits/__init__.py b/tests/traits/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/traits/test_filter_extension.py b/tests/traits/test_filter_extension.py new file mode 100644 index 00000000..ec16e06a --- /dev/null +++ b/tests/traits/test_filter_extension.py @@ -0,0 +1,331 @@ +import arro3 +import geopandas as gpd +import pytest +from shapely.geometry import Point +from traitlets import TraitError + +import lonboard +from lonboard.layer_extension import DataFilterExtension + + +@pytest.fixture +def dfe_test_df() -> gpd.GeoDataFrame: + """GeoDataframe for testing DataFilterExtension.""" + d = { + "int_col": [0, 1, 2, 3, 4, 5], + "float_col": [0.0, 1.5, 0.0, 1.5, 0.0, 1.5], + "str_col": ["even", "odd", "even", "odd", "even", "odd"], + "geometry": [ + Point(0, 0), + Point(1, 1), + Point(2, 2), + Point(3, 3), + Point(4, 4), + Point(5, 5), + ], + } + return gpd.GeoDataFrame(d, crs="EPSG:4326") + + +def test_dfe_no_args_no_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE without args, no get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(), + ], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 1 + assert dfe.category_size is None + assert layer.get_filter_value is None + assert layer.get_filter_category is None + + +def test_dfe_no_args_and_int_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE without args and int get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(), + ], + get_filter_value=dfe_test_df["int_col"], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert 
dfe.filter_size == 1 + assert dfe.category_size is None + assert isinstance(layer.get_filter_value, arro3.core.ChunkedArray) + assert layer.get_filter_category is None + + +def test_dfe_no_args_and_float_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE without args and float get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(), + ], + get_filter_value=dfe_test_df["float_col"], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 1 + assert dfe.category_size is None + assert isinstance(layer.get_filter_value, arro3.core.ChunkedArray) + assert layer.get_filter_category is None + + +def test_dfe_filter_size_no_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size, no get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=1), + ], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 1 + assert dfe.category_size is None + assert layer.get_filter_value is None + assert layer.get_filter_category is None + + +def test_dfe_filter_size_and_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size and get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=1), + ], + get_filter_value=dfe_test_df["int_col"], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 1 + assert dfe.category_size is None + assert isinstance(layer.get_filter_value, arro3.core.ChunkedArray) + assert layer.get_filter_category is None + + +def test_dfe_filter_size2_no_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size 2, no 
get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=2), + ], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 2 + assert dfe.category_size is None + assert layer.get_filter_value is None + assert layer.get_filter_category is None + + +def test_dfe_filter_size2_and_get_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size 2 and get_filter_value + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=2), + ], + get_filter_value=dfe_test_df[["int_col", "float_col"]].values, + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 2 + assert dfe.category_size is None + assert isinstance(layer.get_filter_value, arro3.core.ChunkedArray) + assert layer.get_filter_category is None + + +def test_dfe_cat_no_get_filter_cat(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size None category_size=1, no get_filter_category + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=1), + ], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size is None + assert dfe.category_size == 1 + assert layer.get_filter_value is None + assert layer.get_filter_category is None + + +def test_dfe_cat_and_int_get_filter_cat(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size None category_size=1 and get_filter_category + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=1), + ], + get_filter_category=dfe_test_df["int_col"], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert 
isinstance(dfe, DataFilterExtension) + assert dfe.filter_size is None + assert dfe.category_size == 1 + assert layer.get_filter_value is None + assert isinstance(layer.get_filter_category, arro3.core.ChunkedArray) + + +def test_dfe_cat_and_float_get_filter_cat(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size None category_size=1 and get_filter_category + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=1), + ], + get_filter_category=dfe_test_df["float_col"], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size is None + assert dfe.category_size == 1 + assert layer.get_filter_value is None + assert isinstance(layer.get_filter_category, arro3.core.ChunkedArray) + + +def test_dfe_cat2_get_filter_cat(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size None category_size=2 and get_filter_category + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=2), + ], + get_filter_category=dfe_test_df[["int_col", "float_col"]].values, + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size is None + assert dfe.category_size == 2 + assert layer.get_filter_value is None + assert isinstance(layer.get_filter_category, arro3.core.ChunkedArray) + + +def test_dfe_value_and_cat_no_get_filter_value_or_category( + dfe_test_df: gpd.GeoDataFrame, +): + ## Test DFE with filter_size=1 category_size=1 and no get_filter_value or get_filter_category + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=1, category_size=1), + ], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 1 + assert 
dfe.category_size == 1 + assert layer.get_filter_value is None + assert layer.get_filter_category is None + + +def test_dfe_value_and_cat_and_get_filter_value_category(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size=1 category_size=1 and get_filter_value/get_filter_category + layer = lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=1, category_size=1), + ], + get_filter_value=dfe_test_df["int_col"], + get_filter_category=dfe_test_df["float_col"], + ) + assert len(layer.extensions) == 1 + dfe = layer.extensions[0] + assert isinstance(dfe, DataFilterExtension) + assert dfe.filter_size == 1 + assert dfe.category_size == 1 + assert isinstance(layer.get_filter_value, arro3.core.ChunkedArray) + assert isinstance(layer.get_filter_category, arro3.core.ChunkedArray) + + +def test_dfe_filter_size_none_with_filter_value(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size=None category_size=1 and get_filter_value provided raises + with pytest.raises(TraitError): + lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=1), + ], + get_filter_value=dfe_test_df["float_col"], + ) + + +def test_dfe_category_size_none_with_filter_category(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size=1 category_size=None and get_filter_category provided raises + with pytest.raises(TraitError): + lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=1, category_size=None), + ], + get_filter_category=dfe_test_df["float_col"], + ) + + +def test_dfe_wrong_get_filter_value_size(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size=2 and get_filter_value with 1-D array raises + with pytest.raises(TraitError): + lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=2, category_size=None), + ], + 
get_filter_value=dfe_test_df["float_col"], + ) + + +def test_dfe_wrong_get_filter_value_size2(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with filter_size=1 and get_filter_value with 2-D array raises + with pytest.raises(TraitError): + lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=1, category_size=None), + ], + get_filter_value=dfe_test_df[["int_col", "float_col"]].values, + ) + + +def test_dfe_wrong_get_filter_category_size(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with category_size=2 and get_filter_category with 1-D array raises + with pytest.raises(TraitError): + lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=2), + ], + get_filter_category=dfe_test_df["float_col"], + ) + + +def test_dfe_wrong_get_filter_category_size2(dfe_test_df: gpd.GeoDataFrame): + ## Test DFE with category_size=1 and get_filter_category with 2-D array raises + with pytest.raises(TraitError): + lonboard.ScatterplotLayer.from_geopandas( + dfe_test_df, + extensions=[ + DataFilterExtension(filter_size=None, category_size=1), + ], + get_filter_category=dfe_test_df[["int_col", "float_col"]].values, + )
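Note on the patch above: the shape rule that `FilterCategoryAccessor` enforces for numpy input, and the filtering behaviour the notebook relies on, can be sketched in plain numpy. This is a minimal illustration rather than lonboard's or deck.gl's implementation. The helper names `check_category_shape` and `category_mask` are hypothetical, and the assumption that a row is shown only when every category column matches one of its allowed values is inferred from the notebook's description, not confirmed against the GPU code.

```python
import numpy as np


def check_category_shape(values: np.ndarray, category_size: int) -> np.ndarray:
    """Return values as an (n_rows, category_size) array or raise ValueError.

    Hypothetical helper mirroring the checks in the patch's _numpy_to_arrow:
    1-D input needs category_size == 1, 2-D input needs a matching 2nd dimension.
    """
    if values.ndim == 1:
        if category_size != 1:
            raise ValueError("1-D input requires category_size == 1")
        return values.reshape(-1, 1)
    if values.ndim != 2:
        raise ValueError("expected a 1-D or 2-D array")
    if values.shape[1] != category_size:
        raise ValueError("second dimension must equal category_size")
    return values


def category_mask(values: np.ndarray, allowed: list[list[float]]) -> np.ndarray:
    """Boolean mask of rows to show (assumed semantics: AND across columns)."""
    cols = check_category_shape(values, len(allowed))
    mask = np.ones(cols.shape[0], dtype=bool)
    for i, allowed_i in enumerate(allowed):
        # A row survives column i only if its value is in that column's list.
        mask &= np.isin(cols[:, i], allowed_i)
    return mask


# Columns are (bedrooms, bathrooms), as in the notebook's get_filter_category.
rooms = np.array([[2, 1.0], [3, 1.5], [3, 2.0], [4, 2.0]])
mask = category_mask(rooms, [[3, 4], [2.0]])
print(mask)
```

With 3 or 4 bedrooms and exactly 2 bathrooms allowed, only the last two rows pass. This is also why the notebook's `on_select_change` substitutes the full list of unique values when a selector is empty: under these assumed semantics an empty allowed list would hide every point.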