diff --git a/README.md b/README.md index d95d00f..557eb24 100644 --- a/README.md +++ b/README.md @@ -20,4 +20,4 @@ pip install git+https://github.com/lemma-osu/gee-knn-python@main - numpy - pydantic - scikit-learn -- scikit-learn-knn-regression @ git+https://github.com/lemma-osu/scikit-learn-knn-regression@main +- sknnr @ git+https://github.com/lemma-osu/scikit-learn-knn-regression@main diff --git a/pyproject.toml b/pyproject.toml index 580e1ae..5c6518a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ name = "gee-knn-python" dynamic = ["version"] description = "Python based nearest neighbor mapping in GEE" readme = "README.md" -license = "" requires-python = ">=3.9" authors = [{ name = "Matt Gregory", email = "matt.gregory@oregonstate.edu" }] dependencies = [ @@ -16,7 +15,7 @@ dependencies = [ "numpy", "pydantic", "scikit-learn", - "scikit-learn-knn-regression @ git+https://github.com/lemma-osu/scikit-learn-knn-regression", + "sknnr", ] [project.urls] diff --git a/src/geeknn/_base.py b/src/geeknn/_base.py index b8152bb..8f2bcc3 100644 --- a/src/geeknn/_base.py +++ b/src/geeknn/_base.py @@ -12,7 +12,13 @@ from pydantic import BaseModel from sklearn.base import TransformerMixin -from .utils import crosswalk_to_ids, filter_neighbors, get_k_neighbors, scores_to_fc +from .utils import ( + Colocation, + crosswalk_to_ids, + filter_neighbors, + get_k_neighbors, + scores_to_fc, +) class Geometry(BaseModel): @@ -159,9 +165,21 @@ def _(self, X_image: ee.Image, mode: str = "CLASSIFICATION"): return self._predict_image(X_image, mode=mode) @predict.register - def _(self, fc: ee.FeatureCollection, colocation_obj=None): + def _( + self, + fc: ee.FeatureCollection, + colocation_obj: Optional[Colocation] = None, # noqa: UP007 + num_threads: int = -1, + chunk_size: int = 500, + ): ids = fc.aggregate_array(self.id_field) - return self._predict_fc(fc, ids, colocation_obj=colocation_obj) + return self._predict_fc( + fc, + ids, + colocation_obj=colocation_obj, + 
num_threads=num_threads, + chunk_size=chunk_size, + ) def _predict_image(self, X_image: ee.Image, mode: str = "CLASSIFICATION"): """Predict the nearest neighbors for the given covariate image.""" @@ -179,9 +197,31 @@ def get_neighbor_band_name(i): .arrayFlatten([band_names]) ) - def _predict_fc(self, fc: ee.FeatureCollection, ids: NDArray, colocation_obj=None): + def _predict_fc( + self, + fc: ee.FeatureCollection, + ids: NDArray, + colocation_obj: Optional[Colocation] = None, # noqa: UP007 + num_threads: int = -1, + chunk_size: int = 500, + ): """Predict the nearest neighbors for the given covariate feature collection.""" - neighbor_fc = fc.classify(classifier=self.clf, outputName="neighbors") + + def _predict_batch(fc): + """Run prediction for a batch of features and return the results + as a list.""" + return ( + ee.FeatureCollection(fc) + .classify(classifier=self.clf, outputName="neighbors") + .toList(fc.size()) + ) + + size = fc.size().getInfo() + chunks = [fc.toList(chunk_size, i) for i in range(0, size, chunk_size)] + + with Parallel(n_jobs=num_threads, backend="threading") as p: + chunk_data = p(delayed(_predict_batch)(chunk) for chunk in chunks) + neighbor_fc = ee.FeatureCollection(ee.List(chunk_data).flatten()) neighbor_fc = crosswalk_to_ids(neighbor_fc, ids, self.id_field) if colocation_obj is not None: neighbor_fc = filter_neighbors(neighbor_fc, colocation_obj, self.id_field)