From 4185a463897333d460d7a3f6c79f791a1667512b Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Thu, 7 Jul 2022 13:15:55 -0400 Subject: [PATCH] Notebook accessible on colab --- README.md | 2 +- demo.ipynb | 1004 ++++++++++++++++------------------------------------ 2 files changed, 303 insertions(+), 703 deletions(-) diff --git a/README.md b/README.md index d62794d..75b1493 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ conda install 'fiona>=1.5' 'rasterio>=1.2.6' pip install cropharvest ``` -### Getting started +### Getting started [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nasaharvest/cropharvest/blob/main/demo.ipynb) See the [`demo.ipynb`](https://github.com/nasaharvest/cropharvest/blob/main/demo.ipynb) notebook for an example on how to download the data from [Zenodo](https://zenodo.org/record/5828893) and train a random forest against this data. For more examples of models trained against this dataset, see the [benchmarks](https://github.com/nasaharvest/cropharvest/blob/main/benchmarks). diff --git a/demo.ipynb b/demo.ipynb index 1fd23bc..dfeb2ac 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -1,717 +1,317 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "325c9810", - "metadata": {}, - "source": [ - "# CropHarvest Demo\n", - "\n", - "**Authors**: Gabriel Tseng, Ivan Zvonkov\n", - "\n", - "**Description**: This notebook demonstrates the capabilities of the CropHarvest package by training and testing a model on a subset of the data and then running inference using the trained model." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "e62c6553", - "metadata": {}, - "outputs": [], - "source": [ - "from cropharvest.datasets import CropHarvest\n", - "from cropharvest.inference import Inference\n", - "from pathlib import Path\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "\n", - "import requests\n", - "import tempfile\n", - "\n", - "DATA_DIR = \"data\"" - ] - }, - { - "cell_type": "markdown", - "id": "7703facb", - "metadata": {}, - "source": [ - "## Load datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "d8ef538e", - "metadata": { - "scrolled": true - }, - "outputs": [ + "cells": [ { - "data": { - "text/plain": [ - "[CropHarvestEval(Kenya_1_maize, Kenya_maize),\n", - " CropHarvestEval(Brazil_0_coffee, Brazil_coffee),\n", - " CropHarvestEval(Togo_crop, togo-eval)]" + "cell_type": "markdown", + "id": "325c9810", + "metadata": { + "id": "325c9810" + }, + "source": [ + "# CropHarvest Demo\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nasaharvest/cropharvest/blob/main/demo.ipynb)\n", + "\n", + "**Authors**: Gabriel Tseng, Ivan Zvonkov\n", + "\n", + "**Description**: This notebook demonstrates the capabilities of the CropHarvest package by training and testing a model on a subset of the data and then running inference using the trained model." ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "evaluation_datasets = CropHarvest.create_benchmark_datasets(DATA_DIR)\n", - "evaluation_datasets" - ] - }, - { - "cell_type": "markdown", - "id": "e79554ac", - "metadata": {}, - "source": [ - "## Split Togo data into X and y" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8f326f27", - "metadata": {}, - "outputs": [ + }, + { + "cell_type": "code", + "source": [ + "# Download from PyPI\n", + "!pip install cropharvest -q\n", + "\n", + "# Download from TestPyPI\n", + "#!pip install -i https://test.pypi.org/simple/ cropharvest --extra-index-url https://pypi.python.org/simple -q" + ], + "metadata": { + "id": "beyzvBH4nrU-" + }, + "id": "beyzvBH4nrU-", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip freeze | grep cropharvest " + ], + "metadata": { + "id": "0S8zzJWwn155" + }, + "id": "0S8zzJWwn155", + "execution_count": null, + "outputs": [] + }, { - "data": { - "text/plain": [ - "((1290, 216), (1290,))" + "cell_type": "code", + "execution_count": null, + "id": "e62c6553", + "metadata": { + "id": "e62c6553" + }, + "outputs": [], + "source": [ + "from cropharvest.datasets import CropHarvest\n", + "from cropharvest.inference import Inference\n", + "from pathlib import Path\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "\n", + "import requests\n", + "import tempfile\n", + "\n", + "DATA_DIR = \"data\"\n", + "\n", + "!mkdir $DATA_DIR" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "togo_dataset = evaluation_datasets[-1]\n", - "X, y = togo_dataset.as_array(flatten_x=True)\n", - "X.shape, y.shape" - ] - }, - { - "cell_type": "markdown", - "id": "147e7c93", - "metadata": {}, - "source": [ - "## Train a Random Forest model on the Togo dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "8351896b", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "RandomForestClassifier(random_state=0)" + "cell_type": "markdown", + "id": "7703facb", + "metadata": { + "id": "7703facb" + }, + "source": [ + "## Load datasets" ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = RandomForestClassifier(random_state=0)\n", - "model.fit(X, y)" - ] - }, - { - "cell_type": "markdown", - "id": "c03e4273", - "metadata": {}, - "source": [ - "## Make predictions on Togo test set" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0f47915e", - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "For the Random Forest classifier, {'auc_roc': 0.8954952830188679, 'f1_score': 0.7401574803149605, 'iou': 0.5875, 'num_samples': 306}, \n" - ] - } - ], - "source": [ - "test_preds, test_instances = [], []\n", - "for _, test_instance in togo_dataset.test_data(flatten_x=True):\n", - " test_preds.append(model.predict_proba(test_instance.x)[:, 1])\n", - " test_instances.append(test_instance)\n", - " \n", - "print(\n", - " f\"For the Random Forest classifier, \"\n", - " f\"{test_instances[0].evaluate_predictions(test_preds[0])}, \"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b9b06817", - "metadata": {}, - "source": [ - "## Get test file for inference" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "8e5e60cd", - "metadata": {}, - "outputs": [], - "source": [ - "test_file = \"98-togo_2019-02-06_2020-02-01.tif\"\n", - "\n", - "temp_dir = tempfile.gettempdir()\n", - "p = Path(temp_dir) / test_file\n", - "response = requests.get(\n", - " f\"https://github.com/nasaharvest/cropharvest/blob/main/test/cropharvest/{test_file}?raw=true\", \n", - ")\n", - "with p.open(\"wb\") as f:\n", - " f.write(response.content)" - ] - }, - { - "cell_type": "markdown", - "id": "e61d7ae0", - "metadata": {}, - "source": [ - "## Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2402cbb4", - "metadata": { - "scrolled": false - }, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "id": "d8ef538e", + "metadata": { + "scrolled": true, + "id": "d8ef538e" + }, + "outputs": [], + "source": [ + "evaluation_datasets = CropHarvest.create_benchmark_datasets(DATA_DIR)\n", + "evaluation_datasets" + ] + }, { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:       (lat: 17, lon: 17)\n",
-       "Coordinates:\n",
-       "  * lat           (lat) float64 7.719 7.719 7.719 7.719 ... 7.72 7.72 7.72 7.72\n",
-       "  * lon           (lon) float64 1.422 1.422 1.422 1.422 ... 1.423 1.423 1.424\n",
-       "Data variables:\n",
-       "    prediction_0  (lat, lon) float64 0.26 0.27 0.27 0.27 ... 0.28 0.28 0.28 0.28
" - ], - "text/plain": [ - "\n", - "Dimensions: (lat: 17, lon: 17)\n", - "Coordinates:\n", - " * lat (lat) float64 7.719 7.719 7.719 7.719 ... 7.72 7.72 7.72 7.72\n", - " * lon (lon) float64 1.422 1.422 1.422 1.422 ... 1.423 1.423 1.424\n", - "Data variables:\n", - " prediction_0 (lat, lon) float64 0.26 0.27 0.27 0.27 ... 0.28 0.28 0.28 0.28" + "cell_type": "markdown", + "id": "e79554ac", + "metadata": { + "id": "e79554ac" + }, + "source": [ + "## Split Togo data into X and y" ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "preds = Inference(model=model, normalizing_dict=None).run(p)\n", - "preds" - ] - }, - { - "cell_type": "markdown", - "id": "98bf4ee3", - "metadata": {}, - "source": [ - "## [Optional] Visualize model prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "33f80c45", - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[31mgoogle-api-python-client 1.12.8 has requirement google-api-core<2dev,>=1.21.0, but you'll have google-api-core 2.3.2 which is incompatible.\u001b[0m\r\n", - "\u001b[33mYou are using pip version 10.0.1, however version 22.0.3 is available.\r\n", - "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\r\n" - ] - } - ], - "source": [ - "!pip install matplotlib -q" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "0d559099", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "6c1f182b", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "id": "8f326f27", + "metadata": { + "id": "8f326f27" + }, + "outputs": [], + "source": [ + "togo_dataset = evaluation_datasets[-1]\n", + "X, y = togo_dataset.as_array(flatten_x=True)\n", + "\n", + "assert X.shape[0] == 1290\n", + "assert y.shape[0] == 1290\n", + "assert X.shape[1] == 216\n", + "\n", + "X.shape, y.shape" + ] + }, + { + "cell_type": "markdown", + "id": "147e7c93", + "metadata": { + "id": "147e7c93" + }, + "source": [ + "## Train a Random Forest model on the Togo dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8351896b", + "metadata": { + "id": "8351896b" + }, + "outputs": [], + "source": [ + "model = RandomForestClassifier(random_state=0)\n", + "model.fit(X, y)" + ] + }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAEGCAYAAACtqQjWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAb6klEQVR4nO3dfZRV9X3v8fdHFBUUFEWLCCWJQmNqGGWuMU81SjASgw81JlBNjHoXEmyXtEqF1pj05t67TMm1Ses1lPoATSyNRrhq1KgxBnoRY4EQBAWBEI3BgIpXFJYxku/9Y/9GN8czZ84Mv3NmJnxea501e/8e9vluFsyHvfc5eysiMDMzy2Gf7i7AzMx+fzhUzMwsG4eKmZll41AxM7NsHCpmZpbNvt1dQHc6fFCfGDFsv+4uw8ysV1m+6jcvRsTgan17daiMGLYfjz8wvLvLMDPrVfoMWf9Me30+/WVmZtk4VMzMLBuHipmZZeNQMTOzbBwqZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtk4VMzMLJuGhYqkUZJWll7bJU2rGDO91L9a0i5JgyQNk/SIpCclrZF0RWnOIEkPSVqffh6a2i+QtErSE5IelTS6UftmZmbVNSxUImJdRLRERAswBtgJLKwYM6s0ZiawKCK2AW8CV0bEccDJwOWSjkvTZgAPR8SxwMNpHWATcEpEHA98FZjTqH0zM7PqmnX6ayywMSLavQkZMAmYDxARz0fEirT8KvAUMDSNOxuYl5bnAeekcY9GxMup/THg6Jw7YGZmHWtWqEwkBUY1kvoBZwB3VukbAZwA/CQ1HRkRz6flXwNHVtnkpcD97bzXZEnLJC174aVdde+AmZl1rOGhIqkvcBZwR41hE4Al6dRXee5BFEEzLSK2V06KiACiYs6pFKFydbU3iog5EdEaEa2DD+vTqX0xM7PamnGkMh5YERFbaox5x5GMpP0oAuW2iFhQ6toiaUgaMwTYWprzfuAm4OyIeClT/WZmVqdmhMpb10qqkTQQOAW4q9Qm4GbgqYi4vmLK3cBFafmitnmShgMLgM9FxNPZqjczs7o1NFQk9QfGUfyyb2ubImlKadi5wIMRsaPU9mHgc8BppY8cfzL1XQeMk7Qe+HhaB7gWOAy4MY1f1pi9MjOz9qi4LLF3ah19QPhxwmZmndNnyPrlEdFarc/fqDczs2wcKmZmlo1DxczMsnGomJlZNg4VMzPLxqFiZmbZOFTMzCwbh4qZmWXjUDEzs2wcKmZmlo1DxczMsnGomJlZNg4VMzPLxqFiZmbZOFTMzCwbh4qZmWXjUDEzs2wcKmZmlo1DxczMsmlYqEgaJWll6bVd0rSKMdNL/asl7ZI0SNIwSY9IelLSGklXlOYMkvSQpPXp56GpXZL+UdIGSaskndiofTMzs+oaFioRsS4iWiKiBRgD7AQWVoyZVRozE1gUEduAN4ErI+I44GTgcknHpWkzgIcj4ljg4bQOMB44Nr0mA99q1L6ZmVl1zTr9NRbYGBHP1BgzCZgPEBHPR8SKtPwq8BQwNI07G5iXlucB55Ta/zUKjwGHSBqSdS/MzKymZoXKRFJgVCOpH3AGcGeVvhHACcBPUtOREfF8Wv41cGRaHgr8sjT1Od4OovL2JktaJmnZCy/t6uRumJlZLQ0PFUl9gbOAO2oMmwAsSae+ynMPogiaaRGxvXJSRAQQnaknIuZERGtEtA4+rE9nppqZWQeacaQyHlgREVtqjHnHkYyk/SgC5baIWFDq2tJ2Wiv93JrafwUMK407OrWZmVmTNCNU3rpWUo2kgcApwF2lNgE3A09FxPUVU+4GLkrLF5Xm3Q18Pn0K7GTgldJpMjMza4KGhoqk/sA4YEGpbYqkKaVh5wIPRsSOUtuHgc8Bp5U+cvzJ1HcdME7SeuDjaR3gPuDnwAbgX4CpjdgnMzNrn4rLEnun1tEHxOMPDO/uMszMepU+Q9Yvj4jWan3+Rr2ZmWXjUDEzs2wcKmZmlo1DxczMsnGomJlZNg4VMzPLxqFiZmbZOFTMzCwbh4qZmWXjUDEzs2wcKmZmlo1DxczMsnGomJlZNg4VMzPLxqFiZmbZOFTMzCwbh4qZmWXjUDEzs2wcKmZmlo1DxczMsmlYqEgaJWll6bVd0rSKMdNL/asl7ZI0KPXdImmrpNUVc0ZLWirpCUn3SBqQ2veTNC+1PyVpZqP2zczMqmtYqETEuohoiYgWYAywE1hYMWZWacxMYFFEbEvdc4Ezqmz6JmBGRByftjc9tZ8P7J/axwCXSRqRdafMzKymZp3+GgtsjIhnaoyZBMxvW4mIxcC2KuNGAovT8kPAeW1TgP6S9gUOBN4Atu9h3WZm1gnNCpWJlAKjkqR+FEcld9axrTXA2Wn5fGBYWv4esAN4HngW+HrpqKf8XpMlLZO07IWXdtW/B2Zm1qGGh4qkvsBZwB01hk0AllQLgSouAaZKWg4cTHFEAnASsAs4CngXcKWkd1dOjog5EdEaEa2DD+vTiT0xM7OO7NuE9xgPrIiILTXG1DySKYuItcDpAJJGAmemrj8DfhARvwW2SloCtAI/72rhZmbWOc04/bXbtZJKkgYCpwB31bMxSUekn/sA1wCzU9ezwGmprz9wMrC2y1WbmVmnNTRU0i/3ccCCUtsUSVNKw84FHoyIHRVz5wNLgVGSnpN0aeqaJOlpisDYDNya2v83cJCkNcB/ArdGxKpG7JeZmVWniOjuGrrN/sOHxdCrpmXb3obP/HO2bZmZ5XLM7Zdl3d6mK65aHhGt1fr8jXozM8vGoWJmZtk4VMzMLBuHipmZZeNQMTOzbBwqZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtk4VMzMLBuHipmZZdOM56n0WMcf+gKP+yaQZmbZ1HWkosKFkq5N68MlndTY0szMrLep9/TXjcAHKR64BfAqxfNLzMzM3lLv6a8PRMSJkn4KEBEvp2fPm5mZvaXeI5XfSuoDBICkwcDvGlaVmZn1SvWGyj8CC4EjJP0P4P8C/7NhVZmZWa9U1+mviLhN0nJgLCDgnIh4qqGVmZlZr1PzSEXSoLYXsBWYD/wbsCW11Zo7StLK0mu7pGkVY6aX+ldL2tW2XUm3SNoqaXXFnNGSlkp6QtI9kgaU+t6f+tak/gM69adhZmZ7pKPTX8uBZennC8DTwPq0vLzWxIhYFxEtEdECjAF2UpxCK4+ZVRozE1gUEdtS91zgjCqbvgmYERHHp+1NB5C0L/AdYEpEvA/4GPDbDvbPzMwyqhkqEfGuiHg38ENgQkQcHhGHAZ8CHuzE+4wFNkbEMzXGTKI4Emp778XAtirjRgKL0/JDwHlp+XRgVUT8LM1/KSJ2daJGMzPbQ/VeqD85Iu5rW4mI+4EPdeJ9JlIKjEqS+lEcldxZx7bWAGen5fOBYWl5JBCSHpC0QtJft/NekyUtk7TshZecOWZmOdUbKpslXSNpRHr9LbC5nonp+yxnAXfUGDYBWFI69VXLJcDU9MGBg4E3Uvu+wEeAC9LPcyWNrZwcEXMiojUiWgcf1qeeXTAzszrVGyqTgMEU1zAWAkfw9rfrOzIeWBERW2qMqXkkUxYRayPi9IgYk+ZsTF3PAYsj4sWI2AncB5xYZ41mZpZBvR8p3gZc0cX32O1aSSVJA4FTgAvr2ZikIyJiq6R9gGuA2anrAeCv06m0N9I2/6GLNZuZ/d7YkPnGuX1qpEFdoSLpEdK36csi4rQO5vUHxgGXldqmpLltYXAu8GBE7KiYO5/iE1yHS3oO+HJE3AxMknR5GrYAuDVt72VJ1wP/mWq9LyLurWf/zMwsD0W8IyveOUgaU1o9gOITV29GRNWL4b1F6+gD4vEHhnd3GWZmvUqfIeuXR0Rrtb56T39VfidliaTH97gyMzP7vVLv6a/yt+f3ofgy48CGVGRmZr1Wvbe+X05xnULAm8Am4NJGFWVmZr1TvaHy3oh4vdwgaf8G1GNmZr1Yvd9TebRK29KchZiZWe9X80hF0h8AQ4EDJZ1AcfoLYADQr8G1mZlZL9PR6a9PAF8AjgauL7W/CvxNg2oyM7NeqmaoRMQ8YJ6k8yKinps9mpnZXqyj018XRsR3gBGS/qqyPyKurzLNzMz2Uh2d/uqffh5Upa/jr+L3cE+8PJhjbr+s44F1yn1/HTOz3qaj019tvyV/GBFLyn2SPtywqszMrFeq9yPF/1Rnm5mZ7cU6uqbyQYonPA6uuKYyAPATrszMbDcdXVPpS3E9ZV+Kpyy22Q58ulFFmZlZ79TRNZVFwCJJcyPimSbVZGZmvVS99/7aKWkW8D6K56kAHT+ky8zM9i71Xqi/DVgLvAv4O+AXFE9YNDMze0u9oXJYepTvbyNiUURcAvgoxczMdlPv6a/fpp/PSzoT2AwMqjHezMz2QvWGyn+XNBC4kuL7KQOAaY0qyszMeqe6Tn9FxPcj4pWIWB0Rp0bEGOA9teZIGiVpZem1XdK0ijHTS/2rJe1qe3SxpFskbZW0umLOaElLJT0h6R5JAyr6h0t6TdJV9eybmZnlU+81lWrecYPJsohYFxEtEdFC8Uz7ncDCijGzSmNmAosiYlvqngucUWXTNwEzIuL4tL3pFf3XA/d3blfMzCyHek9/VaOOh7xlLLCxg++6TALmt61ExGJJI6qMGwksTssPAQ8AXwKQdA6wCdjRidrMzH6v5bxxbqH9E0F7cqTSmbsUT6QUGJUk9aM4KqnnmS1rgLPT8vnAsLSNg4CrKT7y3C5JkyUtk7Rs12vOHjOznGqGiqRX07WQyterwFH1vIGkvsBZwB01hk0AlpROfdVyCTBV0nKKW8e8kdq/AvxDRLxWa3JEzImI1oho7XNQ/1pDzcyskzq6TcvBtfrrNB5YERFbaoypeSRTUdNa4HQASSOBM1PXB4BPS/p74BDgd5Jej4gbulq4mZl1zp5cU6nXbtdKKqWPKp8CXFjPxiQdERFbJe0DXAPMBoiIj5bGfAV4zYFiZtZce3JNpUOS+gPjgAWltimSppSGnQs8GBE7KubOB5YCoyQ9J+nS1DVJ0tMUt43ZDNzayH0wM7P6NfRIJQXFYRVtsyvW51J8fLhy7qR2tvlN4JsdvO9XOlepmZnl0NAjFTMz27s4VMzMLBuHipmZZeNQMTOzbBwqZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtk4VMzMLBuHipmZZeNQMTOzbBwqZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtk4VMzMLBuHipmZZdOwZ9RLGgV8t9T0buDaiPhGacx04IJSLe8FBkfENkm3AJ8CtkbEH5fmjAZmAwcBvwAuiIjtksYB1wF9gTeA6RHxowbtnplZr7HhM/+cdXt9rmi/r2FHKhGxLiJaIqIFGAPsBBZWjJlVGjMTWBQR21L3XOCMKpu+CZgREcen7U1P7S8CE1L7RcC38+6RmZl1pFmnv8YCGyPimRpjJgHz21YiYjGwrcq4kcDitPwQcF4a/9OI2Jza1wAHStp/Tws3M7P6NStUJlIKjEqS+lEcldxZx7bWAGen5fOBYVXGnAesiIjfVHmvyZKWSVq267UddbydmZnVq+GhIqkvcBZwR41hE4AlpVNftVwCTJW0HDiY4vpJ+f3eB3wNuKza5IiYExGtEdHa56D+9eyCmZnVqWEX6kvGUxw1bKkxpuaRTFlErAVOB5A0EjizrU/S0RTXWT4fERu7XLGZmXVJM05/7XatpJKkgcApwF31bEzSEennPsA1FJ8EQ9IhwL0UF/GX7FnJZmbWFQ0NFUn9gXHAglLbFElTSsPOBR6MiB0Vc+cDS4FRkp6TdGnqmiTpaWAtsBm4NbX/OXAMcK2klel1REN2zMzMqlJEdHcN3Wb/4cNi6FXTsm0v92fBzcx6oj5D1i+PiNZqff5GvZmZZeNQMTOzbBwqZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtk4VMzMLBuHipmZZeNQMTOzbBwqZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtk4VMzMLJtmPKN+r3HM7Zd1dwnt8gPEzKwZfKRiZmbZOFTMzCybhoWKpFGSVpZe2yVNqxgzvdS/WtIuSYNS3y2StkpaXTFntKSlkp6QdI+kAaW+mZI2SFon6RON2jczM6uuYaESEesioiUiWoAxwE5gYcWYWaUxM4FFEbEtdc8Fzqiy6ZuAGRFxfNredABJxwETgfeleTdK6pN7v8zMrH3NOv01FtgYEc/UGDMJmN+2EhGLgW1Vxo0EFqflh4Dz0vLZwL9HxG8iYhOwAThpTws3M7P6NStUJlIKjEqS+lEcXdxZx7bWUAQIwPnAsLQ8FPhladxzqc3MzJqk4aEiqS9wFnBHjWETgCWlU1+1XAJMlbQcOBh4o5P1TJa0TNKyXa/t6MxUMzPrQDO+pzIeWBERW2qMqXkkUxYRa4HTASSNBM5MXb/i7aMWgKNTW+X8OcAcgP2HD4t63tPMzOrTjNNfu10rqSRpIHAKcFc9G5N0RPq5D3ANMDt13Q1MlLS/pHcBxwKP70HdZmbWSQ0NFUn9gXHAglLbFElTSsPOBR6MiB0Vc+cDS4FRkp6TdGnqmiTpaWAtsBm4FSAi1gC3A08CPwAuj4hdjdkzMzOrRhF77xmg/YcPi6FXTevuMprCt2kxs1z6DFm/PCJaq/X5G/VmZpaNQ8XMzLJxqJiZWTYOFTMzy8ahYmZm2ThUzMwsG4eKmZll41AxM7NsHCpmZpaNQ8XMzLJxqJiZWTYOFTMzy8ahYmZm2ThUzMwsm7361vcDNCg+oLHZtrfxGydn25aZWU+16YqrfOt7MzNrPIeKmZll41AxM7NsHCpmZpaNQ8XMzLJpWKhIGiVpZem1XdK0ijHTS/2rJe2SNCj13SJpq6TVFXNaJD2W5iyTdFJqHyjpHkk/k7RG0sWN2jczM6uuYaESEesioiUiWoAxwE5gYcWYWaUxM4FFEbEtdc8Fzqiy6b8H/i7NuTatA1wOPBkRo4GPAf9LUt+c+2RmZrU16/TXWGBjRDxTY8wkYH7bSkQsBrZVGRfAgLQ8ENhcaj9YkoCD0tw397BuMzPrhH2b9D4TKQVGJUn9KI5K/ryObU0DHpD0dYpQ/FBqvwG4myJkDgY+GxG/q/Jek4HJAAfQr/49MDOzDjX8SCWdgjoLuKPGsAnAktKpr1q+CPxlRAwD/hK4ObV/AlgJHAW0ADdIGlA5OSLmRERrRLTux/5174eZmXWsGae/xgMrImJLjTE1j2QqXAQsSMt3ACel5YuBBVHYAGwC/qgL9ZqZWRc1I1R2u1ZSSdJA4BTgrjq3tzmNBzgNWJ+Wn6W4doOkI4FRwM+7UK+ZmXVRQ28oKak/xS/7d0fEK6ltCkBEzE7rXwDOiIiJFXPnU3yK63BgC/DliLhZ0keAb1JcD3odmBoRyyUdRfGJsSGAgOsi4jsd1PcCUOvDA93hcODF7i6iE1xv4/SmWsH1NlpPqvcPI2JwtY69+i7FPZGkZe3d/bMncr2N05tqBdfbaL2lXn+j3szMsnGomJlZNg6VnmdOdxfQSa63cXpTreB6G61X1OtrKmZmlo2PVMzMLBuHipmZZeNQyay9W/ZXGfdfJL0p6dNpvUXS0nTb/lWSPlsae5ukdenxALdI2i+1X5DGPiHpUUmje2qt7W2rJ9cr6WPpEQtrJC3qyfXmePRDg+q9OdW0StL3JB2U2veX9F1JGyT9RNKIHl7vX0l6MrU/LOkPe3K9pf7zJIWk5n0UOSL8yvgC/gQ4EVhdY0wf4EfAfcCnU9tI4Ni0fBTwPHBIWv8kxRc6RXF3gi+m9g8Bh6bl8cBPemqt7W2rp9YLHAI8CQxP60f08Hr/BvhaWh5McZfuvj2g3gGludcDM9LyVGB2Wp4IfLeH/Pm2V++pQL+0/MWeXm9aPxhYDDwGtHa23q6+fKSSWbR/y/6yvwDuBLaW5j0dEevT8ubUNzit3xcJ8DhwdGp/NCJeTpt4rK29J9ba3rZ6cL1/RnEvuWfTuE7X3OR69/jRDw2qdztAquvAVCfA2cC8tPw9YGwa0yPrjYhHImJn2kSn/601u97kq8DXKO480jQOlSaTNBQ4F/hWjTEnAX2BjRXt+wGfA35QZdqlwP35Ks1baz3b6kn1Uvzv8FBJP5a0XNLne3i9NwDvpbg33hPAFVHl0Q/dUa+kW4FfU9zg9Z9S81DglwAR8SbwCnBYD663LPu/tfS+2eqVdCIwLCLuzV1nRxwqzfcN4Or2/sFLGgJ8G7i4ypgbgcUR8R8Vc06l+It+dQ+utea2Mqn5Hp2sd1+KJ5aeSfFYhS9JGtmD663r0Q/dUW9EXJzqegr4bLW5DfINMtcr6UKgFZjVU+uVtA/FqbArG1Bjx5p1nm1vegEjaOe8KcUt+X+RXq9RHMqek/oGACuocr0B+DLwf4B9KtrfT/G/lpE9udZa2+qh9c6geGx12/rNwPk9uN57gY+W1n8EnNQT6i3N/xPg+2n5AeCDaXlfihslqqfWm9Y/TvGLu9PX15pZL8UTcV8sbet1iiPYplxXafgb7I2vWn9xKsbN5e2LcX2Bh4FpVcb9V+BR4MCK9uHABuBDPb3W9rbVU+ulOJX0cPqF1w9YDfxxD673W8BX0vKRwK+Aw7uzXooPExxTWv468PW0fjm7X6i/vbv/PnRQ7wkU/3k7tit1NrveinE/pokX6pv1OOG9hkq37Jf0HMX/KveDt2/3347PUPxP4zAVjwMA+EJErARmU9yif2m6lrkgIv4bcC3FeegbU/ub0Ym7mDa51j3WzHoj4ilJPwBWAb8DboqImh8F7c56KS7KzpX0BMUvmKsjolO3Sc9dL8Wf3bx0Gk7Azyg+OQXFkd+3JW2guHi926MvemC9syg+AHFH+nN/NiLO6sH1dhvfpsXMzLLxhXozM8vGoWJmZtk4VMzMLBuHipmZZeNQMTOzbBwqZl0k6bUGb/8+SYek19QuzP+YpO83ojaz9jhUzHqoiPhkRPw/ijsmdzpUzLqDQ8Uso/Tsi8fS8y0WSjo0tf9Y0tckPS7paUkfTe39JN2entWxUMWzRVpT3y8kHQ5cB7xHxbNdZlUegUi6oe1LcZLOkLRW0grgT0tj+qt4nsfjkn4q6ezm/anY3sShYpbXv1J8m/39FHcL/nKpb9+IOAmYVmqfCrwcEccBX6K4iWWlGcDGiGiJiOntvbGkA4B/ASak7fxBqftvgR+l9z8VmCWpfxf2z6wmh4pZJpIGUjw8qe0pkfMobq/RZkH6uZziHlAAHwH+HSDdBmbVHpTwR8CmiFgfxa0yvlPqOx2YIWklxb2gDqC4d5xZVr73l1nz/Cb93MWe/dt7k93/Q3hAHXMEnBcR6/bgfc065CMVs0wi4hXg5bbrJRQP0ero2fZLKG4YiKTjgOOrjHmV4tGwbZ4BjlPxnPdDgLGpfS0wQtJ70vqk0pwHgL9ITwhE0gl17ZRZJ/lIxazr+qW7zba5HrgImC2pH/Bz4OIOtnEjxZ1mn6QIhTUUT0F8S0S8JGmJpNXA/RExXdLtFLfj3wT8NI17XdJk4F5JO4H/4O0w+irFQ6BWpYc4bQI+1cX9NmuX71Js1o0k9QH2S4HwHuCHwKiIeKObSzPrEh+pmHWvfsAj6ZnzAqY6UKw385GKmZll4wv1ZmaWjUPFzMyycaiYmVk2DhUzM8vGoWJmZtn8f5FWDQchMNT4AAAAAElFTkSuQmCC", - "text/plain": [ - "
" + "cell_type": "markdown", + "id": "c03e4273", + "metadata": { + "id": "c03e4273" + }, + "source": [ + "## Make predictions on Togo test set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f47915e", + "metadata": { + "id": "0f47915e" + }, + "outputs": [], + "source": [ + "test_preds, test_instances = [], []\n", + "for _, test_instance in togo_dataset.test_data(flatten_x=True):\n", + " test_preds.append(model.predict_proba(test_instance.x)[:, 1])\n", + " test_instances.append(test_instance)\n", + " \n", + "print(\n", + " f\"For the Random Forest classifier, \"\n", + " f\"{test_instances[0].evaluate_predictions(test_preds[0])}, \"\n", + ")\n", + "\n", + "metrics = test_instances[0].evaluate_predictions(test_preds[0])\n", + "assert metrics[\"f1_score\"] > 0.73, \"Default model f1-score should be greater than 0.73\"\n", + "assert metrics[\"auc_roc\"] > 0.88, \"Default model AUC-ROC should be greater than 0.88\"" + ] + }, + { + "cell_type": "markdown", + "id": "b9b06817", + "metadata": { + "id": "b9b06817" + }, + "source": [ + "## Get test file for inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e5e60cd", + "metadata": { + "id": "8e5e60cd" + }, + "outputs": [], + "source": [ + "test_file = \"98-togo_2019-02-06_2020-02-01.tif\"\n", + "\n", + "temp_dir = tempfile.gettempdir()\n", + "p = Path(temp_dir) / test_file\n", + "response = requests.get(\n", + " f\"https://github.com/nasaharvest/cropharvest/blob/main/test/cropharvest/{test_file}?raw=true\", \n", + ")\n", + "with p.open(\"wb\") as f:\n", + " f.write(response.content)" + ] + }, + { + "cell_type": "markdown", + "id": "e61d7ae0", + "metadata": { + "id": "e61d7ae0" + }, + "source": [ + "## Run inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2402cbb4", + "metadata": { + "scrolled": false, + "id": "2402cbb4" + }, + "outputs": [], + "source": [ + "preds = Inference(model=model, normalizing_dict=None).run(p)\n", + "\n", + "# Check size\n", + "assert preds.dims[\"lat\"] == 17\n", + "assert preds.dims[\"lon\"] == 17\n", + "\n", + "# Check all predictions between 0 and 1\n", + "assert preds.min() >= 0\n", + "assert preds.max() <= 1\n", + "\n", + "preds" + ] + }, + { + "cell_type": "markdown", + "id": "98bf4ee3", + "metadata": { + "id": "98bf4ee3" + }, + "source": [ + "## [Optional] Visualize model prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33f80c45", + "metadata": { + "id": "33f80c45" + }, + "outputs": [], + "source": [ + "!pip install matplotlib -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d559099", + "metadata": { + "id": "0d559099" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c1f182b", + "metadata": { + "id": "6c1f182b" + }, + "outputs": [], + "source": [ + "preds_np = preds.to_array()[0]\n", + "plt.pcolormesh(preds_np.lon, preds_np.lat, preds_np.data)\n", + "plt.xlabel(\"Longitude\")\n", + "plt.ylabel(\"Latitude\");" ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" } - ], - "source": [ - "preds_np = preds.to_array()[0]\n", - "plt.pcolormesh(preds_np.lon, preds_np.lat, preds_np.data)\n", - "plt.xlabel(\"Longitude\")\n", - "plt.ylabel(\"Latitude\");" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9c519ac", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + }, + "colab": { + "name": "demo.ipynb", + "provenance": [] + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file