From 6db6c060f0128f02f40094b8f6c8af19c5337d9f Mon Sep 17 00:00:00 2001
From: Jelmer Kuperus
Date: Sat, 30 Dec 2023 10:54:28 +0100
Subject: [PATCH] move spark code to its own project

---
 .github/workflows/ci.yml                      |   25 +-
 .github/workflows/publish.yml                 |   23 +-
 .github/workflows/release.yml                 |    5 +
 .gitignore                                    |    9 -
 .run/Template ScalaTest.run.xml               |   16 -
 README.md                                     |    4 +-
 build.sbt                                     |   81 +-
 .../github/jelmerk/knn/util/ArrayBitSet.java  |    1 +
 .../quick_start_google_colab.ipynb            |  447 --------
 .../similarity.ipynb                          |  601 ----------
 .../hnswlib-examples-pyspark-luigi/README.md  |   19 -
 .../bruteforce_index.py                       |   35 -
 .../hnswlib-examples-pyspark-luigi/convert.py |   34 -
 .../evaluate_performance.py                   |   50 -
 .../hnswlib-examples-pyspark-luigi/flow.py    |  327 ------
 .../hnsw_index.py                             |   37 -
 .../hnswlib-examples-pyspark-luigi/query.py   |   35 -
 hnswlib-spark/README.md                       |  151 ---
 hnswlib-spark/run-pyspark-tests.sh            |   24 -
 hnswlib-spark/src/main/python/__init__.py     |    0
 hnswlib-spark/src/main/python/com/__init__.py |    0
 .../src/main/python/com/github/__init__.py    |    0
 .../python/com/github/jelmerk/__init__.py     |    0
 .../com/github/jelmerk/spark/__init__.py      |    0
 .../jelmerk/spark/conversion/__init__.py      |    4 -
 .../com/github/jelmerk/spark/knn/__init__.py  |    0
 .../jelmerk/spark/knn/bruteforce/__init__.py  |    4 -
 .../jelmerk/spark/knn/evaluation/__init__.py  |    4 -
 .../github/jelmerk/spark/knn/hnsw/__init__.py |    4 -
 .../github/jelmerk/spark/linalg/__init__.py   |    4 -
 .../src/main/python/pyspark_hnsw/__init__.py  |  189 ---
 .../main/python/pyspark_hnsw/conversion.py    |   35 -
 .../main/python/pyspark_hnsw/evaluation.py    |   54 -
 .../src/main/python/pyspark_hnsw/knn.py       |  570 ---------
 .../src/main/python/pyspark_hnsw/linalg.py    |   32 -
 .../spark/HnswLibKryoRegistrator.scala        |   18 -
 .../spark/conversion/VectorConverter.scala    |  119 --
 .../jelmerk/spark/knn/KnnAlgorithm.scala      | 1015 -----------------
 .../knn/bruteforce/BruteForceSimilarity.scala |  103 --
 .../evaluation/KnnSimilarityEvaluator.scala   |  119 --
 .../spark/knn/hnsw/HnswSimilarity.scala       |  181 ---
 .../com/github/jelmerk/spark/knn/knn.scala    |  226 ----
 .../jelmerk/spark/linalg/Normalizer.scala     |   96 --
 .../DenseVectorDistanceFunctions.scala        |   78 --
 .../SparseVectorDistanceFunctions.scala       |  339 ------
 .../functions/VectorDistanceFunctions.scala   |  111 --
 .../util/SerializableConfiguration.scala      |   20 -
 .../com/github/jelmerk/spark/util/Utils.scala |   32 -
 hnswlib-spark/src/test/python/conftest.py     |   18 -
 .../src/test/python/test_bruteforce.py        |   23 -
 hnswlib-spark/src/test/python/test_hnsw.py    |   22 -
 .../src/test/python/test_integration.py       |   29 -
 .../src/test/python/test_knn_evaluator.py     |   21 -
 .../src/test/python/test_normalizer.py        |   14 -
 .../src/test/python/test_vector_converter.py  |   14 -
 .../jelmerk/spark/SharedSparkContext.scala    |   43 -
 .../conversion/VectorConverterSpec.scala      |   64 --
 .../jelmerk/spark/knn/SerializerSpec.scala    |   57 -
 .../KnnSimilarityEvaluatorSpec.scala          |   42 -
 .../spark/knn/hnsw/HnswSimilaritySpec.scala   |  255 -----
 .../jelmerk/spark/linalg/NormalizerSpec.scala |   45 -
 .../SparseVectorDistanceFunctionsSpec.scala   |   51 -
 project/plugins.sbt                           |    4 +-
 scripts/colab_setup.sh                        |   43 -
 scripts/kaggle_setup.sh                       |   41 -
 65 files changed, 29 insertions(+), 6038 deletions(-)
 delete mode 100644 .run/Template ScalaTest.run.xml
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/README.md
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/convert.py
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/flow.py
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py
 delete mode 100644 hnswlib-examples/hnswlib-examples-pyspark-luigi/query.py
 delete mode 100644 hnswlib-spark/README.md
 delete mode 100755 hnswlib-spark/run-pyspark-tests.sh
 delete mode 100644 hnswlib-spark/src/main/python/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py
 delete mode 100644 hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py
 delete mode 100644 hnswlib-spark/src/main/python/pyspark_hnsw/evaluation.py
 delete mode 100644 hnswlib-spark/src/main/python/pyspark_hnsw/knn.py
 delete mode 100644 hnswlib-spark/src/main/python/pyspark_hnsw/linalg.py
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/HnswLibKryoRegistrator.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/conversion/VectorConverter.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluator.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/Normalizer.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/DenseVectorDistanceFunctions.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctions.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/VectorDistanceFunctions.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/SerializableConfiguration.scala
 delete mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/Utils.scala
 delete mode 100644 hnswlib-spark/src/test/python/conftest.py
 delete mode 100644 hnswlib-spark/src/test/python/test_bruteforce.py
 delete mode 100644 hnswlib-spark/src/test/python/test_hnsw.py
 delete mode 100644 hnswlib-spark/src/test/python/test_integration.py
 delete mode 100644 hnswlib-spark/src/test/python/test_knn_evaluator.py
 delete mode 100644 hnswlib-spark/src/test/python/test_normalizer.py
 delete mode 100644 hnswlib-spark/src/test/python/test_vector_converter.py
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/SharedSparkContext.scala
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/conversion/VectorConverterSpec.scala
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/SerializerSpec.scala
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluatorSpec.scala
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/NormalizerSpec.scala
 delete mode 100644 hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctionsSpec.scala
 delete mode 100644 scripts/colab_setup.sh
 delete mode 100644 scripts/kaggle_setup.sh

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 9edc2d87..257e8096 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,5 +1,8 @@
 name: CI pipeline
 
+permissions:
+  checks: write
+
 on:
   pull_request:
     paths:
@@ -9,38 +12,24 @@ on:
       - '*'
     tags-ignore:
       - 'v[0-9]+.[0-9]+.[0-9]+'
+    paths-ignore:
+      - '**.md'
 
 jobs:
   ci-pipeline:
     runs-on: ubuntu-22.04
-    strategy:
-      fail-fast: false
-      matrix:
-        spark:
-          - 2.4.8
-          - 3.0.2
-          - 3.1.3
-          - 3.2.4
-          - 3.3.2
-          - 3.4.1
-          - 3.5.0
+
     env:
       ENV: 'ci'
-      SPARK_VERSION: ${{ matrix.spark }}
 
     steps:
       - name: Checkout
         uses: actions/checkout@v3
         with:
           fetch-depth: 0
-      - uses: actions/setup-python@v5
-        with:
-          python-version: |
-            3.7
-            3.9
       - name: Build and test
         run: |
-          sbt -java-home "$JAVA_HOME_17_X64" clean +test -DsparkVersion="$SPARK_VERSION"
+          sbt -java-home "$JAVA_HOME_17_X64" clean +test
       - name: Publish Unit test results
         uses: mikepenz/action-junit-report@v4
         with:
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index d59a51d4..249d06c5 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -1,5 +1,8 @@
 name: Publish pipeline
 
+permissions:
+  contents: read
+
 on:
   workflow_dispatch:
 
@@ -10,21 +13,9 @@ on:
 jobs:
   publish-artifacts:
     runs-on: ubuntu-22.04
-    strategy:
-      fail-fast: false
-      matrix:
-        spark:
-          - 2.4.8
-          - 3.0.2
-          - 3.1.3
-          - 3.2.4
-          - 3.3.2
-          - 3.4.1
-          - 3.5.0
     env:
       ENV: 'ci'
-      SPARK_VERSION: ${{ matrix.spark }}
       NEXUS_USER: ${{ secrets.NEXUS_USER }}
       NEXUS_PASSWORD: ${{ secrets.NEXUS_PASSWORD }}
 
@@ -33,11 +24,6 @@ jobs:
         uses: actions/checkout@v3
         with:
           fetch-depth: 0
-      - uses: actions/setup-python@v5
-        with:
-          python-version: |
-            3.7
-            3.9
       - name: Import GPG Key
         uses: crazy-max/ghaction-import-gpg@v1
         env:
@@ -45,4 +31,5 @@ jobs:
           PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
       - name: Publish artifacts
         run: |
-          sbt -java-home "$JAVA_HOME_17_X64" clean +publishSigned -DsparkVersion="$SPARK_VERSION"
+          sbt -java-home "$JAVA_HOME_17_X64" clean +publishSigned
+          sbt -java-home "$JAVA_HOME_17_X64" sonatypeBundleRelease
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 201abf5b..6cdbc301 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -1,5 +1,8 @@
 name: Release pipeline
 
+permissions:
+  contents: write
+
 on:
   workflow_dispatch:
     inputs:
@@ -13,6 +16,8 @@ jobs:
     steps:
       - name: Checkout main branch
         uses: actions/checkout@v3
+        with:
+          token: ${{ secrets.RELEASE_TOKEN }}
       - name: Release
         run: |
           git config --global user.email "action@github.com"
diff --git a/.gitignore b/.gitignore
index a68f2928..9e0adaa6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,12 +15,3 @@ dist/
 
 # MacOS
 .DS_Store
-
-# Virtual env
-.venv
-
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-.pytest_cache/
\ No newline at end of file
diff --git a/.run/Template ScalaTest.run.xml b/.run/Template ScalaTest.run.xml
deleted file mode 100644
index 412f33d1..00000000
--- a/.run/Template ScalaTest.run.xml
+++ /dev/null
@@ -1,16 +0,0 @@
-
-
-
-
-
diff --git a/README.md b/README.md
index 1bc9bc96..be546cbf 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,9 @@ The following distance metrics are currently pre-packaged :
 - inner product
 - manhattan distance
 
-It comes with [spark integration](https://github.com/jelmerk/hnswlib/tree/master/hnswlib-spark), [pyspark integration](https://github.com/jelmerk/hnswlib/tree/master/hnswlib-pyspark) and a [scala wrapper](https://github.com/jelmerk/hnswlib/tree/master/hnswlib-scala) that should feel native to scala developers
+It comes with a [scala wrapper](https://github.com/jelmerk/hnswlib/tree/master/hnswlib-scala) that should feel native to scala developers
+
+Apache spark support was moved into the [hnswlib-spark](https://github.com/jelmerk/hnswlib-spark) project.
 
 To find out more about how to use this library take a look at the [hnswlib-examples](https://github.com/jelmerk/hnswlib/tree/master/hnswlib-examples) module or browse the documentation in the readme files of the submodules
 
diff --git a/build.sbt b/build.sbt
index d6552ccf..0e578058 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,6 +1,3 @@
-import Path.relativeTo
-import sys.process.*
-
 ThisBuild / organization := "com.github.jelmerk"
 
 ThisBuild / scalaVersion := "2.12.18"
@@ -15,15 +12,8 @@ ThisBuild / Compile / doc / javacOptions ++= {
   Seq("-Xdoclint:none")
 }
 
-val java8Home = sys.env.getOrElse("JAVA_HOME_8_X64", s"${sys.props("user.home")}/.sdkman/candidates/java/8.0.382-amzn")
-
 lazy val publishSettings = Seq(
   pomIncludeRepository := { _ => false },
 
-  publishTo := {
-    val nexus = "https://oss.sonatype.org/"
-    if (isSnapshot.value) Some("snapshots" at nexus + "content/repositories/snapshots")
-    else Some("releases" at nexus + "service/local/staging/deploy/maven2")
-  },
 
   licenses := Seq("Apache License 2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0.html")),
@@ -43,7 +33,11 @@ lazy val publishSettings = Seq(
     "oss.sonatype.org",
     sys.env.getOrElse("NEXUS_USER", ""),
     sys.env.getOrElse("NEXUS_PASSWORD", "")
-  )
+  ),
+
+  publishTo := sonatypePublishToBundle.value,
+  sonatypeSessionName := s"[sbt-sonatype] ${name.value} ${version.value}"
+
 )
 
 lazy val noPublishSettings =
@@ -54,13 +48,8 @@ val junitVersion = "5.5.2"
 val hamcrestVersion = "2.1"
 val mockitoVersion = "3.0.0"
 
-val sparkVersion = settingKey[String]("Spark version")
-
-lazy val pyTest = taskKey[Unit]("Run the python tests")
-lazy val pyPublish = taskKey[Unit]("Publish the python sources to a pypi repo")
-
 lazy val root = (project in file("."))
-  .aggregate(hnswlibUtils, hnswlibCore, hnswlibCoreJdk17, hnswlibMetricsDropwizard, hnswlibScala, hnswlibSpark)
+  .aggregate(hnswlibUtils, hnswlibCore, hnswlibCoreJdk17, hnswlibMetricsDropwizard, hnswlibScala)
   .settings(noPublishSettings)
 
 lazy val hnswlibUtils = (project in file("hnswlib-utils"))
@@ -168,62 +157,4 @@ lazy val hnswlibScala = (project in file("hnswlib-scala"))
     libraryDependencies ++= Seq(
       "org.scalatest" %% "scalatest" % scalaTestVersion % Test
     )
-  )
-
-lazy val hnswlibSpark = (project in file("hnswlib-spark"))
(project in file("hnswlib-spark")) - .dependsOn(hnswlibUtils) - .dependsOn(hnswlibScala) - .settings( - name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("-")}", - publishSettings, - crossScalaVersions := { - if (sparkVersion.value >= "3.2.0") { - Seq("2.12.18", "2.13.10") - } else if (sparkVersion.value >= "3.0.0") { - Seq("2.12.18") - } else { - Seq("2.12.18", "2.11.12") - } - }, - javaHome := Some(file(java8Home)), - Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "python", - Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "python", - Compile / packageBin / mappings ++= { - val base = baseDirectory.value / "src" / "main" / "python" - val srcs = base ** "*.py" - srcs pair relativeTo(base) - }, - assembly / mainClass := None, - assembly / assemblyOption ~= { - _.withIncludeScala(false) - }, - sparkVersion := sys.props.getOrElse("sparkVersion", "3.3.2"), - pyTest := { - val log = streams.value.log - - val artifactPath = (Compile / assembly).value.getAbsolutePath - if (scalaVersion.value == "2.12.18" && sparkVersion.value >= "3.0.0" || scalaVersion.value == "2.11.12") { - val pythonVersion = if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9" - val ret = Process( - Seq("./run-pyspark-tests.sh", sparkVersion.value, pythonVersion), - cwd = baseDirectory.value, - extraEnv = "JAVA_HOME" -> java8Home, "ARTIFACT_PATH" -> artifactPath - ).! - require(ret == 0, "Python tests failed") - } else { - // pyspark packages support just one version of scala. You cannot use 2.13.x because it ships with 2.12.x jars - log.info(s"Running pyTests for Scala ${scalaVersion.value} and Spark ${sparkVersion.value} is not supported.") - } - }, - test := { - (Test / test).value - (Test / pyTest).value - }, - pyTest := pyTest.dependsOn(assembly).value, - libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided, - "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided, - "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test, - "org.scalatest" %% "scalatest" % scalaTestVersion % Test - ) ) \ No newline at end of file diff --git a/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/ArrayBitSet.java b/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/ArrayBitSet.java index b1f46447..2609b2fc 100644 --- a/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/ArrayBitSet.java +++ b/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/ArrayBitSet.java @@ -1,5 +1,6 @@ package com.github.jelmerk.knn.util; +import java.io.Serial; import java.io.Serializable; import java.util.Arrays; diff --git a/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb b/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb deleted file mode 100644 index 31df0359..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb +++ /dev/null @@ -1,447 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "hnswlib.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# HnswLib Quick Start\n", - "\n", - "\n", - "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jelmerk/hnswlib/blob/master/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb)\n", - "\n", - "We will first set up the runtime environment and give it a quick test" - ], - "metadata": { - "id": "NtnuPdiDyN8_" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "F0u73ufErwpG", - "outputId": "15bde5ea-bdb7-4e23-d74d-75f4ff851fab" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2022-01-08 02:32:40-- https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh\n", - "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...\n", - "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 1269 (1.2K) [text/plain]\n", - "Saving to: ‘STDOUT’\n", - "\n", - "- 100%[===================>] 1.24K --.-KB/s in 0s \n", - "\n", - "2022-01-08 02:32:41 (73.4 MB/s) - written to stdout [1269/1269]\n", - "\n", - "setup Colab for PySpark 3.0.3 and Hnswlib 1.0.0\n", - "Installing PySpark 3.0.3 and Hnswlib 1.0.0\n", - "\u001b[K |████████████████████████████████| 209.1 MB 73 kB/s \n", - "\u001b[K |████████████████████████████████| 198 kB 80.2 MB/s \n", - "\u001b[?25h Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!wget https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh -O - | bash" - ] - }, - { - "cell_type": "code", - "source": [ - "import pyspark_hnsw\n", - "\n", - "from pyspark.ml import Pipeline\n", - "from pyspark_hnsw.knn import *\n", - "from pyspark.ml.feature import HashingTF, IDF, Tokenizer\n", - "from pyspark.sql.functions import col, posexplode" - ], - "metadata": { - "id": "nO6TiznusZ2y" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "spark = pyspark_hnsw.start()" - ], - "metadata": { - "id": "Y9KKKcZlscZF" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "print(\"Hnswlib version: {}\".format(pyspark_hnsw.version()))\n", - "print(\"Apache Spark version: {}\".format(spark.version))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CJ2xbiCosydF", - "outputId": "baa771e6-5761-4a4d-fc26-22044aa6aeb5" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Hnswlib version: 1.0.0\n", - "Apache Spark version: 3.0.3\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Load the product data from the [instacart market basket analysis kaggle competition ](https://www.kaggle.com/c/instacart-market-basket-analysis/data?select=products.csv.zip)" - ], - "metadata": { - "id": "nIYBMlF9i6cR" - } - }, - { - "cell_type": "code", - "source": [ - "!wget -O /tmp/products.csv \"https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hOBkUPYO1Zpa", - "outputId": "f003f2ee-bb8c-4b56-a475-a980c992d9da" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2022-01-08 03:58:45-- 
https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz\n", - "Resolving drive.google.com (drive.google.com)... 173.194.79.100, 173.194.79.102, 173.194.79.101, ...\n", - "Connecting to drive.google.com (drive.google.com)|173.194.79.100|:443... connected.\n", - "HTTP request sent, awaiting response... 302 Moved Temporarily\n", - "Location: https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download [following]\n", - "Warning: wildcards not supported in HTTP.\n", - "--2022-01-08 03:58:45-- https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download\n", - "Resolving doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)... 108.177.127.132, 2a00:1450:4013:c07::84\n", - "Connecting to doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)|108.177.127.132|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 2166953 (2.1M) [text/csv]\n", - "Saving to: ‘/tmp/products.csv’\n", - "\n", - "/tmp/products.csv 100%[===================>] 2.07M --.-KB/s in 0.01s \n", - "\n", - "2022-01-08 03:58:45 (159 MB/s) - ‘/tmp/products.csv’ saved [2166953/2166953]\n", - "\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "productData = spark.read.option(\"header\", \"true\").csv(\"/tmp/products.csv\")" - ], - "metadata": { - "id": "oKodvLC6xwO6" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "productData.count()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q4C7HS1LQDcE", - "outputId": "f0b73205-ae29-4218-bd0e-81eb89fc3c4e" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "49688" - ] - }, - "metadata": {}, - "execution_count": 22 - } - ] - }, - { - "cell_type": "code", - "source": [ - "tokenizer = Tokenizer(inputCol=\"product_name\", outputCol=\"words\")\n", - "hashingTF = HashingTF(inputCol=\"words\", outputCol=\"rawFeatures\")\n", - "idf = IDF(inputCol=\"rawFeatures\", outputCol=\"features\")" - ], - "metadata": { - "id": "Zq2yRJevnRGS" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Create a simple TF / IDF model that turns product names into sparse word vectors and adds them to an exact knn index. 
\n", - "\n", - "An exact or brute force index will give 100% correct, will be quick to index but really slow to query and is only appropriate during development or for doing comparissons against an approximate index" - ], - "metadata": { - "id": "S3OkoohFo2IA" - } - }, - { - "cell_type": "code", - "source": [ - "bruteforce = BruteForceSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', k = 5, featuresCol='features', distanceFunction='cosine', excludeSelf=True, numPartitions=10)" - ], - "metadata": { - "id": "ReyTZSM1uT2q" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "exact_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, bruteforce])" - ], - "metadata": { - "id": "20wtg6ZhHpwx" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "exact_model = exact_pipeline.fit(productData)" - ], - "metadata": { - "id": "Ln1aIrdyJRoL" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Next create the same model but add the TF / IDF vectors to a HNSW index" - ], - "metadata": { - "id": "cot3ByIOpwwZ" - } - }, - { - "cell_type": "code", - "source": [ - "hnsw = HnswSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', featuresCol='features',\n", - " distanceFunction='cosine', numPartitions=10, excludeSelf=True, k = 5)" - ], - "metadata": { - "id": "7zLQLVreqWRM" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "hnsw_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, hnsw])" - ], - "metadata": { - "id": "mUlvwo89qEJm" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "hnsw_model = hnsw_pipeline.fit(productData)" - ], - "metadata": { - "id": "dwOkEFmxqeR2" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Select a record to query" - ], - "metadata": { - "id": "MQSYgEgHlg65" - } - }, - { - "cell_type": "code", - "source": [ - "queries = productData.filter(col(\"product_id\") == 43572)" - ], - "metadata": { - "id": "vCag3tH-NUf-" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "queries.show(truncate = False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "pcUCCFxzQ02H", - "outputId": "8721ba75-f5d2-493e-a36c-d182e97a3bd0" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "+----------+-----------------------------+--------+-------------+\n", - "|product_id|product_name |aisle_id|department_id|\n", - "+----------+-----------------------------+--------+-------------+\n", - "|43572 |Alcaparrado Manzanilla Olives|110 |13 |\n", - "+----------+-----------------------------+--------+-------------+\n", - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Show the results from the exact model" - ], - "metadata": { - "id": "qbcUGq4irTFH" - } - }, - { - "cell_type": "code", - "source": [ - "exact_model.transform(queries) \\\n", - " .select(posexplode(col(\"prediction\")).alias(\"pos\", \"item\")) \\\n", - " .select(col(\"pos\"), col(\"item.neighbor\").alias(\"product_id\"), col(\"item.distance\").alias(\"distance\")) \\\n", - " .join(productData, [\"product_id\"]) \\\n", - " .show(truncate=False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q4wi29adOLRX", - "outputId": 
"1b06735b-8db4-4c4f-fe16-7d1aad02ea6d" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|product_id|pos|distance |product_name |aisle_id|department_id|\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|27806 |0 |0.2961162117528633 |Manzanilla Olives |110 |13 |\n", - "|25125 |1 |0.40715716898722976|Stuffed Manzanilla Olives |110 |13 |\n", - "|16721 |2 |0.40715716898722976|Manzanilla Stuffed Olives |110 |13 |\n", - "|39833 |3 |0.49516580877903393|Pimiento Sliced Manzanilla Olives |110 |13 |\n", - "|33495 |4 |0.514201828085252 |Manzanilla Pimiento Stuffed Olives|110 |13 |\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Show the results from the hnsw model" - ], - "metadata": { - "id": "JxHQ10aAr0MQ" - } - }, - { - "cell_type": "code", - "source": [ - "hnsw_model.transform(queries) \\\n", - " .select(posexplode(col(\"prediction\")).alias(\"pos\", \"item\")) \\\n", - " .select(col(\"pos\"), col(\"item.neighbor\").alias(\"product_id\"), col(\"item.distance\").alias(\"distance\")) \\\n", - " .join(productData, [\"product_id\"]) \\\n", - " .show(truncate=False)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PupolEF6P0jc", - "outputId": "9c0ce36d-32ae-4494-d277-6a7246a17588" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|product_id|pos|distance |product_name |aisle_id|department_id|\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "|27806 |0 |0.2961162117528633 |Manzanilla Olives |110 |13 |\n", - "|25125 |1 |0.40715716898722976|Stuffed Manzanilla Olives |110 |13 |\n", - "|16721 |2 |0.40715716898722976|Manzanilla Stuffed Olives |110 |13 |\n", - "|33495 |3 |0.514201828085252 |Manzanilla Pimiento Stuffed Olives|110 |13 |\n", - "|41472 |4 |0.514201828085252 |Pimiento Stuffed Manzanilla Olives|110 |13 |\n", - "+----------+---+-------------------+----------------------------------+--------+-------------+\n", - "\n" - ] - } - ] - } - ] -} \ No newline at end of file diff --git a/hnswlib-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb b/hnswlib-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb deleted file mode 100644 index ff2d6de4..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-jupyter-notebook/similarity.ipynb +++ /dev/null @@ -1,601 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Jupyter hnswlib example\n", - "\n", - "This notebook demonstrates how to use hnswlib with pyspark in a jupyter notebook" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Download data \n", - "\n", - "The first line is dropped for it to be readable as csv input" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!curl \"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz\" | gunzip | tail -n +2 | gzip > data.gz " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - 
"source": [ - "from pyspark.ml import Pipeline\n", - "from pyspark.ml.feature import VectorAssembler\n", - "from pyspark_hnsw.conversion import VectorConverter\n", - "from pyspark_hnsw.knn import *\n", - "from pyspark_hnsw.linalg import Normalizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Read the data as a spark dataframe" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "words_df = spark.read \\\n", - " .option('delimiter', ' ') \\\n", - " .option('inferSchema', 'true') \\\n", - " .option(\"quote\", \"\\u0000\") \\\n", - " .csv('data.gz') \\\n", - " .withColumnRenamed('_c0', 'id')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Inspect the schema" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "root\n", - " |-- id: string (nullable = true)\n", - " |-- _c1: double (nullable = true)\n", - " |-- _c2: double (nullable = true)\n", - " |-- _c3: double (nullable = true)\n", - " |-- _c4: double (nullable = true)\n", - " |-- _c5: double (nullable = true)\n", - " |-- _c6: double (nullable = true)\n", - " |-- _c7: double (nullable = true)\n", - " |-- _c8: double (nullable = true)\n", - " |-- _c9: double (nullable = true)\n", - " |-- _c10: double (nullable = true)\n", - " |-- _c11: double (nullable = true)\n", - " |-- _c12: double (nullable = true)\n", - " |-- _c13: double (nullable = true)\n", - " |-- _c14: double (nullable = true)\n", - " |-- _c15: double (nullable = true)\n", - " |-- _c16: double (nullable = true)\n", - " |-- _c17: double (nullable = true)\n", - " |-- _c18: double (nullable = true)\n", - " |-- _c19: double (nullable = true)\n", - " |-- _c20: double (nullable = true)\n", - " |-- _c21: double (nullable = true)\n", - " |-- _c22: double (nullable = true)\n", - " |-- _c23: double (nullable = true)\n", - " |-- _c24: double (nullable = true)\n", - " |-- _c25: double (nullable = true)\n", - " |-- _c26: double (nullable = true)\n", - " |-- _c27: double (nullable = true)\n", - " |-- _c28: double (nullable = true)\n", - " |-- _c29: double (nullable = true)\n", - " |-- _c30: double (nullable = true)\n", - " |-- _c31: double (nullable = true)\n", - " |-- _c32: double (nullable = true)\n", - " |-- _c33: double (nullable = true)\n", - " |-- _c34: double (nullable = true)\n", - " |-- _c35: double (nullable = true)\n", - " |-- _c36: double (nullable = true)\n", - " |-- _c37: double (nullable = true)\n", - " |-- _c38: double (nullable = true)\n", - " |-- _c39: double (nullable = true)\n", - " |-- _c40: double (nullable = true)\n", - " |-- _c41: double (nullable = true)\n", - " |-- _c42: double (nullable = true)\n", - " |-- _c43: double (nullable = true)\n", - " |-- _c44: double (nullable = true)\n", - " |-- _c45: double (nullable = true)\n", - " |-- _c46: double (nullable = true)\n", - " |-- _c47: double (nullable = true)\n", - " |-- _c48: double (nullable = true)\n", - " |-- _c49: double (nullable = true)\n", - " |-- _c50: double (nullable = true)\n", - " |-- _c51: double (nullable = true)\n", - " |-- _c52: double (nullable = true)\n", - " |-- _c53: double (nullable = true)\n", - " |-- _c54: double (nullable = true)\n", - " |-- _c55: double (nullable = true)\n", - " |-- _c56: double (nullable = true)\n", - " |-- _c57: double (nullable = true)\n", - " |-- _c58: double (nullable = true)\n", - " |-- _c59: double (nullable = true)\n", - " |-- _c60: double (nullable = 
true)\n", - " |-- _c61: double (nullable = true)\n", - " |-- _c62: double (nullable = true)\n", - " |-- _c63: double (nullable = true)\n", - " |-- _c64: double (nullable = true)\n", - " |-- _c65: double (nullable = true)\n", - " |-- _c66: double (nullable = true)\n", - " |-- _c67: double (nullable = true)\n", - " |-- _c68: double (nullable = true)\n", - " |-- _c69: double (nullable = true)\n", - " |-- _c70: double (nullable = true)\n", - " |-- _c71: double (nullable = true)\n", - " |-- _c72: double (nullable = true)\n", - " |-- _c73: double (nullable = true)\n", - " |-- _c74: double (nullable = true)\n", - " |-- _c75: double (nullable = true)\n", - " |-- _c76: double (nullable = true)\n", - " |-- _c77: double (nullable = true)\n", - " |-- _c78: double (nullable = true)\n", - " |-- _c79: double (nullable = true)\n", - " |-- _c80: double (nullable = true)\n", - " |-- _c81: double (nullable = true)\n", - " |-- _c82: double (nullable = true)\n", - " |-- _c83: double (nullable = true)\n", - " |-- _c84: double (nullable = true)\n", - " |-- _c85: double (nullable = true)\n", - " |-- _c86: double (nullable = true)\n", - " |-- _c87: double (nullable = true)\n", - " |-- _c88: double (nullable = true)\n", - " |-- _c89: double (nullable = true)\n", - " |-- _c90: double (nullable = true)\n", - " |-- _c91: double (nullable = true)\n", - " |-- _c92: double (nullable = true)\n", - " |-- _c93: double (nullable = true)\n", - " |-- _c94: double (nullable = true)\n", - " |-- _c95: double (nullable = true)\n", - " |-- _c96: double (nullable = true)\n", - " |-- _c97: double (nullable = true)\n", - " |-- _c98: double (nullable = true)\n", - " |-- _c99: double (nullable = true)\n", - " |-- _c100: double (nullable = true)\n", - " |-- _c101: double (nullable = true)\n", - " |-- _c102: double (nullable = true)\n", - " |-- _c103: double (nullable = true)\n", - " |-- _c104: double (nullable = true)\n", - " |-- _c105: double (nullable = true)\n", - " |-- _c106: double (nullable = true)\n", - " |-- _c107: double (nullable = true)\n", - " |-- _c108: double (nullable = true)\n", - " |-- _c109: double (nullable = true)\n", - " |-- _c110: double (nullable = true)\n", - " |-- _c111: double (nullable = true)\n", - " |-- _c112: double (nullable = true)\n", - " |-- _c113: double (nullable = true)\n", - " |-- _c114: double (nullable = true)\n", - " |-- _c115: double (nullable = true)\n", - " |-- _c116: double (nullable = true)\n", - " |-- _c117: double (nullable = true)\n", - " |-- _c118: double (nullable = true)\n", - " |-- _c119: double (nullable = true)\n", - " |-- _c120: double (nullable = true)\n", - " |-- _c121: double (nullable = true)\n", - " |-- _c122: double (nullable = true)\n", - " |-- _c123: double (nullable = true)\n", - " |-- _c124: double (nullable = true)\n", - " |-- _c125: double (nullable = true)\n", - " |-- _c126: double (nullable = true)\n", - " |-- _c127: double (nullable = true)\n", - " |-- _c128: double (nullable = true)\n", - " |-- _c129: double (nullable = true)\n", - " |-- _c130: double (nullable = true)\n", - " |-- _c131: double (nullable = true)\n", - " |-- _c132: double (nullable = true)\n", - " |-- _c133: double (nullable = true)\n", - " |-- _c134: double (nullable = true)\n", - " |-- _c135: double (nullable = true)\n", - " |-- _c136: double (nullable = true)\n", - " |-- _c137: double (nullable = true)\n", - " |-- _c138: double (nullable = true)\n", - " |-- _c139: double (nullable = true)\n", - " |-- _c140: double (nullable = true)\n", - " |-- _c141: double (nullable = true)\n", - " |-- _c142: 
double (nullable = true)\n", - " |-- _c143: double (nullable = true)\n", - " |-- _c144: double (nullable = true)\n", - " |-- _c145: double (nullable = true)\n", - " |-- _c146: double (nullable = true)\n", - " |-- _c147: double (nullable = true)\n", - " |-- _c148: double (nullable = true)\n", - " |-- _c149: double (nullable = true)\n", - " |-- _c150: double (nullable = true)\n", - " |-- _c151: double (nullable = true)\n", - " |-- _c152: double (nullable = true)\n", - " |-- _c153: double (nullable = true)\n", - " |-- _c154: double (nullable = true)\n", - " |-- _c155: double (nullable = true)\n", - " |-- _c156: double (nullable = true)\n", - " |-- _c157: double (nullable = true)\n", - " |-- _c158: double (nullable = true)\n", - " |-- _c159: double (nullable = true)\n", - " |-- _c160: double (nullable = true)\n", - " |-- _c161: double (nullable = true)\n", - " |-- _c162: double (nullable = true)\n", - " |-- _c163: double (nullable = true)\n", - " |-- _c164: double (nullable = true)\n", - " |-- _c165: double (nullable = true)\n", - " |-- _c166: double (nullable = true)\n", - " |-- _c167: double (nullable = true)\n", - " |-- _c168: double (nullable = true)\n", - " |-- _c169: double (nullable = true)\n", - " |-- _c170: double (nullable = true)\n", - " |-- _c171: double (nullable = true)\n", - " |-- _c172: double (nullable = true)\n", - " |-- _c173: double (nullable = true)\n", - " |-- _c174: double (nullable = true)\n", - " |-- _c175: double (nullable = true)\n", - " |-- _c176: double (nullable = true)\n", - " |-- _c177: double (nullable = true)\n", - " |-- _c178: double (nullable = true)\n", - " |-- _c179: double (nullable = true)\n", - " |-- _c180: double (nullable = true)\n", - " |-- _c181: double (nullable = true)\n", - " |-- _c182: double (nullable = true)\n", - " |-- _c183: double (nullable = true)\n", - " |-- _c184: double (nullable = true)\n", - " |-- _c185: double (nullable = true)\n", - " |-- _c186: double (nullable = true)\n", - " |-- _c187: double (nullable = true)\n", - " |-- _c188: double (nullable = true)\n", - " |-- _c189: double (nullable = true)\n", - " |-- _c190: double (nullable = true)\n", - " |-- _c191: double (nullable = true)\n", - " |-- _c192: double (nullable = true)\n", - " |-- _c193: double (nullable = true)\n", - " |-- _c194: double (nullable = true)\n", - " |-- _c195: double (nullable = true)\n", - " |-- _c196: double (nullable = true)\n", - " |-- _c197: double (nullable = true)\n", - " |-- _c198: double (nullable = true)\n", - " |-- _c199: double (nullable = true)\n", - " |-- _c200: double (nullable = true)\n", - " |-- _c201: double (nullable = true)\n", - " |-- _c202: double (nullable = true)\n", - " |-- _c203: double (nullable = true)\n", - " |-- _c204: double (nullable = true)\n", - " |-- _c205: double (nullable = true)\n", - " |-- _c206: double (nullable = true)\n", - " |-- _c207: double (nullable = true)\n", - " |-- _c208: double (nullable = true)\n", - " |-- _c209: double (nullable = true)\n", - " |-- _c210: double (nullable = true)\n", - " |-- _c211: double (nullable = true)\n", - " |-- _c212: double (nullable = true)\n", - " |-- _c213: double (nullable = true)\n", - " |-- _c214: double (nullable = true)\n", - " |-- _c215: double (nullable = true)\n", - " |-- _c216: double (nullable = true)\n", - " |-- _c217: double (nullable = true)\n", - " |-- _c218: double (nullable = true)\n", - " |-- _c219: double (nullable = true)\n", - " |-- _c220: double (nullable = true)\n", - " |-- _c221: double (nullable = true)\n", - " |-- _c222: double (nullable = true)\n", - " 
|-- _c223: double (nullable = true)\n", - " |-- _c224: double (nullable = true)\n", - " |-- _c225: double (nullable = true)\n", - " |-- _c226: double (nullable = true)\n", - " |-- _c227: double (nullable = true)\n", - " |-- _c228: double (nullable = true)\n", - " |-- _c229: double (nullable = true)\n", - " |-- _c230: double (nullable = true)\n", - " |-- _c231: double (nullable = true)\n", - " |-- _c232: double (nullable = true)\n", - " |-- _c233: double (nullable = true)\n", - " |-- _c234: double (nullable = true)\n", - " |-- _c235: double (nullable = true)\n", - " |-- _c236: double (nullable = true)\n", - " |-- _c237: double (nullable = true)\n", - " |-- _c238: double (nullable = true)\n", - " |-- _c239: double (nullable = true)\n", - " |-- _c240: double (nullable = true)\n", - " |-- _c241: double (nullable = true)\n", - " |-- _c242: double (nullable = true)\n", - " |-- _c243: double (nullable = true)\n", - " |-- _c244: double (nullable = true)\n", - " |-- _c245: double (nullable = true)\n", - " |-- _c246: double (nullable = true)\n", - " |-- _c247: double (nullable = true)\n", - " |-- _c248: double (nullable = true)\n", - " |-- _c249: double (nullable = true)\n", - " |-- _c250: double (nullable = true)\n", - " |-- _c251: double (nullable = true)\n", - " |-- _c252: double (nullable = true)\n", - " |-- _c253: double (nullable = true)\n", - " |-- _c254: double (nullable = true)\n", - " |-- _c255: double (nullable = true)\n", - " |-- _c256: double (nullable = true)\n", - " |-- _c257: double (nullable = true)\n", - " |-- _c258: double (nullable = true)\n", - " |-- _c259: double (nullable = true)\n", - " |-- _c260: double (nullable = true)\n", - " |-- _c261: double (nullable = true)\n", - " |-- _c262: double (nullable = true)\n", - " |-- _c263: double (nullable = true)\n", - " |-- _c264: double (nullable = true)\n", - " |-- _c265: double (nullable = true)\n", - " |-- _c266: double (nullable = true)\n", - " |-- _c267: double (nullable = true)\n", - " |-- _c268: double (nullable = true)\n", - " |-- _c269: double (nullable = true)\n", - " |-- _c270: double (nullable = true)\n", - " |-- _c271: double (nullable = true)\n", - " |-- _c272: double (nullable = true)\n", - " |-- _c273: double (nullable = true)\n", - " |-- _c274: double (nullable = true)\n", - " |-- _c275: double (nullable = true)\n", - " |-- _c276: double (nullable = true)\n", - " |-- _c277: double (nullable = true)\n", - " |-- _c278: double (nullable = true)\n", - " |-- _c279: double (nullable = true)\n", - " |-- _c280: double (nullable = true)\n", - " |-- _c281: double (nullable = true)\n", - " |-- _c282: double (nullable = true)\n", - " |-- _c283: double (nullable = true)\n", - " |-- _c284: double (nullable = true)\n", - " |-- _c285: double (nullable = true)\n", - " |-- _c286: double (nullable = true)\n", - " |-- _c287: double (nullable = true)\n", - " |-- _c288: double (nullable = true)\n", - " |-- _c289: double (nullable = true)\n", - " |-- _c290: double (nullable = true)\n", - " |-- _c291: double (nullable = true)\n", - " |-- _c292: double (nullable = true)\n", - " |-- _c293: double (nullable = true)\n", - " |-- _c294: double (nullable = true)\n", - " |-- _c295: double (nullable = true)\n", - " |-- _c296: double (nullable = true)\n", - " |-- _c297: double (nullable = true)\n", - " |-- _c298: double (nullable = true)\n", - " |-- _c299: double (nullable = true)\n", - " |-- _c300: double (nullable = true)\n", - "\n" - ] - } - ], - "source": [ - "words_df.printSchema()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - 
"source": [ - "## Fit the model\n", - "\n", - "The cosine distance is obtained with the inner product after normalizing all vectors to unit norm. This is faster than calculating the cosine distance directly" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "vector_assembler = VectorAssembler(inputCols=words_df.columns[1:], outputCol='features_as_vector')\n", - "\n", - "converter = VectorConverter(inputCol='features_as_vector', outputCol='features')\n", - "\n", - "normalizer = Normalizer(inputCol='features', outputCol='normalized_features')\n", - "\n", - "hnsw = HnswSimilarity(identifierCol='id', queryIdentifierCol='id', featuresCol='normalized_features', \n", - " distanceFunction='inner-product', m=48, ef=5, k=10, efConstruction=200, numPartitions=2, \n", - " excludeSelf=True, predictionCol='approximate', outputFormat='minimal')\n", - " \n", - "pipeline = Pipeline(stages=[vector_assembler, converter, normalizer, hnsw])\n", - "\n", - "model = pipeline.fit(words_df)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Transform the data\n", - "\n", - "Show the most similar words for a 1% sample" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "|id |approximate |\n", - "+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "|12,203 |[[12,207, 0.18658859], [12,208, 0.19709927], [9,203, 0.2014007], [7,846, 0.2022888], [9,534, 0.20384425], [10,870, 0.20649981], [9,084, 0.20763409], [9,088, 0.20896888], [10,864, 0.21028435], [5,415, 0.21110404]] |\n", - "|jobsPest |[[jobsQuality, 0.25308883], [jobsAviation, 0.26029062], [jobsPlumbing, 0.26665843], [jobsDiesel, 0.2683605], [jobsPackage, 0.26983005], [jobsLawn, 0.2711926], [jobsElectronics, 0.27372217], [jobsFire, 0.2752136], [jobsPediatrics, 0.2768286], [jobsPRN, 0.27690232]] |\n", - "|Fabric-Free |[[Contact4, 0.4873442], [RankingsCross-Platform, 0.48872536], [Twistys-, 0.49384797], [PatternsSewing, 0.49997622], [Results---------------------Trio, 0.50160384], [BootsUmbro, 0.5024311], [PostsCrazy, 0.50568044], [E-booksLiquid, 0.50938785], [ShoppingMainResourcesHome, 0.51089483], [IIStar, 0.5164968]] |\n", - "|MEDLEY |[[SMEDLEY, 0.33541584], [STATON, 0.4237159], [REDMAN, 0.4317326], [INSTRUMENTAL, 0.43596303], [LANDRUM, 0.43741482], [MELODY, 0.43790233], [EARLS, 
0.44213486], [REEDY, 0.4467486], [LAKIN, 0.44767928], [CAPELLA, 0.44937247]] |\n", - "|Cherifa |[[Sherifa, 0.21229106], [Houria, 0.2890727], [Mariem, 0.29752994], [Nesrine, 0.30011368], [Souad, 0.30613774], [Youmna, 0.30868888], [Najat, 0.31098378], [Othmane, 0.3109982], [Aïcha, 0.3119046], [Mounira, 0.3127805]] |\n", - "|handling.The |[[handlingThe, 0.2960384], [HandlingThe, 0.348755], [acceleration.The, 0.38781983], [performance.The, 0.42423737], [grip.The, 0.4348228], [styling.The, 0.4419613], [reliability.The, 0.442101], [processing.The, 0.44788152], [quality.The, 0.4515201], [preparation.The, 0.4519968]] |\n", - "|feeding. |[[feeding-, 0.2464127], [breastfeeding., 0.24948114], [pumping., 0.2880875], [feeding.The, 0.2901587], [nursing., 0.30490905], [feed., 0.30967218], [eating., 0.33619535], [feeds., 0.3368044], [feeders., 0.33771157], [bfeeding, 0.34225684]] |\n", - "|INCREASE |[[DECREASE, 0.17253554], [INCREASES, 0.17714423], [INCREASED, 0.18786818], [INCREASING, 0.23869586], [DECREASED, 0.24525923], [DECREASING, 0.2651623], [DECREASES, 0.2724021], [REDUCES, 0.28129196], [REDUCE, 0.3192544], [increase, 0.33188963]] |\n", - "|Biedscheid |[[Vasturia, 0.34737885], [McAlarney, 0.3716166], [Abromaitis, 0.37181854], [Cougill, 0.38580287], [Harangody, 0.3867358], [Chrabascz, 0.3958456], [LeDay, 0.39698923], [Mosquera-Perea, 0.40215093], [Swopshire, 0.41246688], [Olekaibe, 0.41323322]] |\n", - "|StonesBlood |[[StoriesBlood, 0.2574374], [Pain-MigrainesBlood, 0.25935745], [DisordersBlood, 0.27709854], [toothbrushesBlood, 0.2863251], [VaginitisBlood, 0.35550547], [CancerUterine, 0.42338544], [citiesBloomfield, 0.43873364], [ScansBlood, 0.43958473], [DripPulmonary, 0.44069505], [Bubbles17Potty, 0.4551038]] |\n", - "|Dwyka |[[Ecca, 0.3334368], [Huab, 0.35957432], [Ugab, 0.37153518], [tillite, 0.38799363], [Neocomian, 0.3959421], [Palmwag, 0.39790404], [Tillite, 0.40405828], [calcrete, 0.40923548], [Kuiseb, 0.41476083], [Karroo, 0.41948062]] |\n", - "|Salmonoid |[[klempie, 0.17076665], [bergshark, 0.20234323], [robdylan, 0.20383084], [MorneN, 0.21372688], [JarsonX, 0.21756864], [rekinek, 0.22028464], [VinChainSaw, 0.22337723], [FireTheLooser, 0.24493623], [rhineshark, 0.2945698], [war1, 0.30755812]] |\n", - "|Thiruvananthapuram |[[Kozhikode, 0.113474965], [Trivandrum, 0.13436812], [Thrissur, 0.15189129], [Ernakulam, 0.18022627], [Thiruvanathapuram, 0.1967026], [Kannur, 0.1999836], [Kottayam, 0.20828074], [Alappuzha, 0.2144202], [Palakkad, 0.2157821], [Kerala, 0.2218926]] |\n", - "|411,000 |[[314,000, 0.18651259], [382,000, 0.19640517], [542,000, 0.19697678], [243,000, 0.19744879], [246,000, 0.19793433], [661,000, 0.19919282], [483,000, 0.19965571], [311,000, 0.19969779], [787,000, 0.19983482], [687,000, 0.19987303]] |\n", - "|renormalized |[[renormalize, 0.21896636], [renormalizable, 0.233994], [non-renormalizable, 0.24710214], [renormalization, 0.25644892], [renormalizing, 0.2630415], [perturbatively, 0.26749474], [renormalizability, 0.26908678], [renormalisation, 0.2788254], [one-particle, 0.2796554], [self-energy, 0.28007698]] |\n", - "|tic-tac-toe |[[Tic-Tac-Toe, 0.2675596], [Tic-tac-toe, 0.27591997], [tick-tack-toe, 0.31125695], [TicTacToe, 0.3837089], [game-board, 0.3997907], [pentomino, 0.40678132], [gameboard, 0.43625218], [pictionary, 0.4396732], [parcheesi, 0.44464827], [pentominoes, 0.44885635]] |\n", - "|Bordan |[[Casselle, 0.52586204], [Pershin, 0.5260427], [Vernik, 0.5282048], [Pawlo, 0.5297984], [Silley, 0.5305444], [Karidy, 0.53220797], [Lukich, 0.53514373], [Jarvin, 0.54001564], 
[Romanovitch, 0.5408964], [Vinnik, 0.542225]] |\n", - "|Trump-like |[[Trump-style, 0.24985683], [Trumpish, 0.32918006], [Trump-inspired, 0.3576998], [Trump-led, 0.3673811], [Trumpist, 0.36998725], [Trumpian, 0.38011009], [pussy-grabbing, 0.41307902], [Trump, 0.4192025], [Trump-, 0.43780744], [Obama-esque, 0.44226766]] |\n", - "|plural. |[[singular., 0.19377464], [nouns., 0.29209363], [adjectives., 0.33575755], [plural, 0.34617174], [pronouns., 0.3491757], [verbs., 0.35916436], [unnatural., 0.37773985], [singulars, 0.38103175], [punctuation., 0.38695288], [literal., 0.387487]] |\n", - "|2017AirAsia |[[2017Airport, 0.3248713], [2017Airtel, 0.32582766], [2017Airline, 0.33763862], [2017Air, 0.3726818], [2017Ural, 0.4164347], [2017JAL, 0.44614], [2017S7, 0.47999454], [2016Jetstar, 0.48473775], [2017Airfare, 0.4851548], [2017Lufthansa, 0.49247712]] |\n", - "|Tygh |[[Wamic, 0.46134633], [Seiad, 0.4715523], [Ohop, 0.47594672], [Walterville, 0.4771896], [Yoncalla, 0.48333758], [Ochoco, 0.4844129], [Westfir, 0.49013948], [Aumsville, 0.49770498], [Santiam, 0.49815756], [Dufur, 0.4989429]] |\n", - "|Corfield |[[Wakelam, 0.47453016], [Thursfield, 0.47644395], [Relton, 0.47709066], [Darnborough, 0.4790309], [Barker, 0.4950601], [Hunnisett, 0.5000353], [Clare, 0.50022745], [Loveday, 0.5014505], [McCullagh, 0.5016621], [Rushbrook, 0.5021677]] |\n", - "|off-topicness |[[off-topic, 0.39978695], [off-topics, 0.43869048], [any-other-ism, 0.46748966], [offtopic, 0.46778005], [threadjacking, 0.48088962], [noobishness, 0.48138767], [double-posting, 0.49016875], [thread-jacking, 0.49423033], [nitpickery, 0.49565685], [ailurophobia, 0.51669383]] |\n", - "|geometricians |[[geometrician, 0.2514078], [geometers, 0.28362888], [dialecticians, 0.3291492], [mathematicians, 0.35669494], [metaphysicians, 0.35981405], [metaphysicists, 0.36269522], [mechanicians, 0.36907917], [topographers, 0.37184918], [topologists, 0.37911403], [philosophers, 0.379712]] |\n", - "|probe-target |[[12-mer, 0.406626], [unhybridized, 0.4090193], [primer-template, 0.42718935], [intercalator, 0.42875183], [oligonucleotide, 0.42901552], [multimer, 0.42940372], [heteroduplexes, 0.43040854], [RNA-RNA, 0.43168443], [RNA-DNA, 0.43540668], [cross-hybridization, 0.43555927]] |\n", - "|Gilmores |[[Lorelai, 0.42448652], [Gallaghers, 0.4953835], [Garveys, 0.49884313], [Latimers, 0.5017904], [Bravermans, 0.510707], [Downeys, 0.51435757], [Huxtables, 0.5155351], [Dunphys, 0.52001923], [Donnellys, 0.5225915], [Merediths, 0.528318]] |\n", - "|re-entrant |[[reentrant, 0.15459806], [Reentrant, 0.4052536], [reentrance, 0.52774024], [radiused, 0.54940414], [lock-based, 0.550535], [reentrancy, 0.5559192], [fipple, 0.5608648], [inverted-V, 0.570162], [ramp-like, 0.5722155], [exception-handling, 0.5789175]] |\n", - "|retreatEntire |[[RetreatEntire, 0.12567973], [villaEntire, 0.1762833], [flatEntire, 0.18074965], [cottageEntire, 0.18832248], [quietEntire, 0.1888001], [loftEntire, 0.18973964], [houseEntire, 0.19146818], [paradiseEntire, 0.19647264], [condoEntire, 0.19726574], [aptEntire, 0.20008034]] |\n", - "|TDCC |[[C.I.P., 0.53586763], [Skindred, 0.5382503], [ZBR, 0.54636574], [HDIF, 0.55777407], [10CC, 0.5630076], [DTB, 0.5633284], [NDZ, 0.56419915], [BK3, 0.5651463], [HTDA, 0.5656252], [RXB, 0.5676259]] |\n", - "|effort-reward |[[U-curve, 0.3932373], [Work-family, 0.43020344], [work-home, 0.44562095], [effort-related, 0.4461912], [work-family, 0.45366144], [effort-based, 0.45776397], [family-work, 0.45854646], [JD-R, 0.46326238], [within-person, 0.46719086], 
[inverted-U, 0.47469378]] |\n", - "|--San |[[-San, 0.35914522], [--Los, 0.42107773], [.San, 0.42540854], [theSan, 0.43825322], [San, 0.44703954], [ex-San, 0.4533084], [--Library, 0.45704937], [--Santos, 0.45948273], [--Publishers, 0.46684676], [--Sandy, 0.46889085]] |\n", - "|self-deceived |[[self-deceiving, 0.22922957], [self-deluded, 0.24861062], [self-deceptive, 0.26969224], [deluded, 0.3012687], [self-delusional, 0.33168864], [deceived, 0.33604693], [self-deception, 0.34932947], [self-deluding, 0.36128765], [self-deceit, 0.3613189], [delusioned, 0.39297193]] |\n", - "|Gallivan |[[Kearns, 0.4857906], [Halloran, 0.4983012], [Scanlon, 0.5002695], [Dermody, 0.5029112], [Monihan, 0.51301515], [McNaboe, 0.5131852], [Hickey, 0.51328695], [Donahue, 0.51339185], [Conlin, 0.51695275], [Galivan, 0.51788837]] |\n", - "|knows |[[thinks, 0.21201801], [understands, 0.2261492], [KNOWS, 0.26246667], [wants, 0.27767152], [cares, 0.3033799], [believes, 0.32917213], [know, 0.34224766], [loves, 0.35013437], [realizes, 0.35853004], [sees, 0.38445485]] |\n", - "|time.com |[[healthland.time.com, 0.31079954], [www.time.com, 0.31666458], [business.time.com, 0.32041848], [Time.com, 0.32402343], [content.time.com, 0.3254763], [cbsnews.com, 0.3340934], [newsfeed.time.com, 0.3349049], [ti.me, 0.3349563], [forbes.com, 0.35376084], [huffingtonpost.com, 0.36056894]] |\n", - "|Honno |[[Lolfa, 0.39942592], [Gwasg, 0.44998926], [Mabinogion, 0.4501353], [Hiraeth, 0.45157075], [Shinchosha, 0.4525395], [Mslexia, 0.45460403], [Kobunsha, 0.455235], [Cymreig, 0.4695266], [Llandysul, 0.4730023], [Peirene, 0.4730605]] |\n", - "|295.8 |[[495.8, 0.17986465], [277.7, 0.18505645], [277.9, 0.18510514], [279.6, 0.18561989], [259.4, 0.19097263], [302.8, 0.19234681], [315.8, 0.19361043], [331.2, 0.19427317], [236.3, 0.19730085], [307.5, 0.19739395]] |\n", - "|Sát |[[Tất, 0.21651655], [Sư, 0.22710967], [Nữ, 0.23190814], [Tổ, 0.24468482], [Đạt, 0.24681973], [Tiếp, 0.24921513], [cơ, 0.24926937], [Để, 0.25090778], [Số, 0.251406], [tạo, 0.25180024]] |\n", - "|linux-firmware |[[gentoo-sources, 0.2748819], [kernel-devel, 0.2809406], [linux-headers, 0.2838297], [linux-ck, 0.2840634], [ati-drivers, 0.2860148], [alsa-driver, 0.2894156], [wireless-tools, 0.29206634], [linux-, 0.2965474], [x11-drivers, 0.29671365], [nvidia-utils, 0.29729158]] |\n", - "|USNWC |[[ORRC, 0.4963314], [UREC, 0.5053611], [Nantahala, 0.50777054], [Riverfest, 0.515401], [CRWC, 0.51723826], [Canoecopia, 0.5192678], [Whitewater, 0.52118385], [RRGCC, 0.52146745], [Paddlesport, 0.52546215], [SFWDA, 0.5266501]] |\n", - "|Mizuna |[[mizuna, 0.28863245], [Komatsuna, 0.36378044], [Tatsoi, 0.37227207], [Arugula, 0.39033133], [Radicchio, 0.4118532], [Lettuces, 0.41477716], [Mesclun, 0.43122846], [Shungiku, 0.435322], [Shishito, 0.44291615], [komatsuna, 0.45512182]] |\n", - "|produire |[[réduire, 0.30024856], [conduire, 0.30443156], [traduire, 0.31812984], [construire, 0.33397806], [quantité, 0.3619085], [comprendre, 0.36905253], [paraît, 0.37056595], [permettre, 0.37167805], [prenez, 0.37477142], [appeler, 0.3760041]] |\n", - "|Vishvamitra |[[Visvamitra, 0.2109701], [Vasishtha, 0.26352656], [Vishwamitra, 0.26676786], [Viswamitra, 0.2689237], [Vasistha, 0.28435642], [Durvasa, 0.29395294], [Parshurama, 0.3035227], [Angiras, 0.3126335], [Viśvāmitra, 0.31265825], [Lakshamana, 0.3264414]] |\n", - "|Alaris |[[Kodak, 0.54300165], [Sicor, 0.54351544], [Lumeris, 0.5489408], [Skyra, 0.5492325], [Agfa, 0.551931], [Rotech, 0.55331594], [Eclipsys, 0.5579692], [McKesson, 0.55963635], [Cohera, 
0.56248426], [Skytron, 0.5640091]] |\n", - "|Glib |[[glib, 0.462829], [GLib, 0.48213995], [Gtk2, 0.4969753], [gobject, 0.530207], [GObject, 0.53471434], [Gtk, 0.54887956], [gtkmm, 0.54921776], [PyGTK, 0.5493102], [ocaml, 0.54951996], [pygtk, 0.5542872]] |\n", - "|.547 |[[.508, 0.13394785], [.468, 0.15660906], [.562, 0.15937334], [.459, 0.18963867], [.525, 0.19699472], [.570, 0.20013428], [.485, 0.20840245], [.472, 0.21170974], [.730, 0.2182309], [.515, 0.22378993]] |\n", - "|GLOSS |[[GLOSSY, 0.3989271], [MATTE, 0.497608], [SUAVE, 0.5245309], [LACQUER, 0.5251291], [SATIN, 0.5333786], [COLOURED, 0.542069], [PIGMENT, 0.5423312], [RIMMEL, 0.55206597], [COLOUR, 0.55828285], [ECRU, 0.55858755]] |\n", - "|200026th |[[200026, 0.18761241], [200023rd, 0.18821812], [200030th, 0.2668029], [199226th, 0.26723677], [200118th, 0.2696054], [199117th, 0.27169496], [199229th, 0.27382308], [199621st, 0.27706194], [199614th, 0.2772705], [199822nd, 0.27854365]] |\n", - "|war.When |[[war.As, 0.17672849], [war.While, 0.18371683], [war.But, 0.19178396], [war.And, 0.19600874], [war.Now, 0.19864732], [war.If, 0.2034421], [war.We, 0.20659214], [war.So, 0.20679426], [war.With, 0.21247083], [war.The, 0.2191723]] |\n", - "|Buto |[[Hermopolis, 0.44407785], [Nekhen, 0.45114744], [Anhur, 0.45315647], [Ombos, 0.45505714], [Waset, 0.4636187], [Ipy, 0.4712559], [Amenmose, 0.47222078], [Piankhi, 0.47222567], [Heracleopolis, 0.4738437], [Khnum, 0.47964293]] |\n", - "|agoEuro |[[agoEuropean, 0.14380676], [agoEurovision, 0.2558602], [agoExtreme, 0.3950531], [agoGirl, 0.4019295], [agoAsian, 0.40496123], [agoAnal, 0.40714467], [agoPetite, 0.41152078], [agoTight, 0.4160725], [agoYoung, 0.41789395], [agoErotic, 0.42120898]] |\n", - "|Chegg |[[Chegg.com, 0.28263116], [CourseSmart, 0.38691306], [InstaEDU, 0.38738722], [ValoreBooks, 0.44380403], [Textbooks.com, 0.45656365], [eTextbooks, 0.46011853], [e-textbook, 0.4664542], [chegg, 0.4726361], [Zinch, 0.4732185], [eTextbook, 0.47742146]] |\n", - "|yoyoing |[[yoyo, 0.51077217], [YoYoFactory, 0.63088036], [yo-yoed, 0.6395584], [degrassi, 0.6421582], [yo-yo-ing, 0.6464759], [168lbs, 0.6467501], [34lbs, 0.64780664], [weightclass, 0.65541697], [yo-yos, 0.65626645], [27lbs, 0.65648466]] |\n", - "|Psychotherapist |[[psychotherapist, 0.17837918], [Psychologist, 0.25483716], [Hypnotherapist, 0.26325], [Psychoanalyst, 0.2636032], [Therapist, 0.27796215], [Psychotherapy, 0.31045443], [Psychiatrist, 0.32049966], [Psychometrist, 0.34207618], [Neuropsychologist, 0.34612286], [hypnotherapist, 0.34960145]] |\n", - "|vsg |[[wls, 0.4130748], [pgc, 0.48014772], [l4l, 0.4840045], [cnf, 0.48659766], [ubr, 0.49018586], [bdd, 0.4916944], [sbar, 0.4919929], [flk, 0.49292332], [nsg, 0.49387354], [hba, 0.4952656]] |\n", - "|Genevan |[[Genevans, 0.29040444], [Genevese, 0.36319685], [Calvinist, 0.44801384], [Zwinglian, 0.45843595], [Vaudois, 0.4630267], [Helvetian, 0.46329576], [Genevois, 0.46833795], [Romanist, 0.47816467], [Zwingli, 0.48320568], [Castellio, 0.4834407]] |\n", - "|PackagesAll |[[beveragesAll, 0.36015803], [ImagesAll, 0.36921114], [ImagesAlthough, 0.45730734], [PackagesThe, 0.47000843], [PackagesWe, 0.48384595], [Packages-, 0.4847052], [PackagesAbout, 0.48732245], [Packages, 0.5011395], [KidsAll, 0.51534486], [OptionsAll, 0.5155984]] |\n", - "|Respiration |[[respiration, 0.24006212], [Respirations, 0.34009868], [respiration., 0.36672527], [Photosynthesis, 0.376208], [Ventilatory, 0.40083295], [Breathing, 0.4182803], [respirations, 0.42138022], [Physiology, 0.42760956], [Transpiration, 0.45514685], 
[Oxygenation, 0.46287602]] |\n", - "|KLINGER |[[KLING, 0.23572946], [RENNER, 0.3281743], [BOLLINGER, 0.34154397], [LINDER, 0.34276795], [KLINE, 0.34491885], [BECKMAN, 0.34566534], [SEITZ, 0.3464508], [MOSER, 0.35336167], [BOCK, 0.35609585], [MUELLER, 0.35618508]] |\n", - "|comunicato |[[dedicato, 0.2575605], [comunicativo, 0.2723022], [comunica, 0.2803359], [gestione, 0.2846164], [rapporto, 0.28886348], [affari, 0.28932178], [comunicativa, 0.29086506], [comunicazione, 0.2919432], [significato, 0.2928834], [controllo, 0.29501212]] |\n", - "|policer |[[policers, 0.24065274], [policy-map, 0.33857065], [class-default, 0.3778031], [class-map, 0.37915397], [CBWFQ, 0.39920253], [dscp, 0.43556625], [per-flow, 0.4382425], [match-any, 0.44513267], [subinterface, 0.45000672], [DSCP, 0.4596861]] |\n", - "|Harten |[[Heuven, 0.41431504], [Daele, 0.41610622], [Nunen, 0.42204052], [Biezen, 0.42723382], [Zanten, 0.4326527], [Welie, 0.43394887], [Langenhove, 0.44216907], [Meeteren, 0.4449361], [Dissel, 0.44622403], [Ostrand, 0.4474976]] |\n", - "|Fili |[[Moala, 0.52602714], [Fíli, 0.5298693], [Grima, 0.53000826], [Bofur, 0.5492251], [Lete, 0.54962724], [Casan, 0.5560886], [Filipo, 0.565763], [Salita, 0.56959856], [Tauriel, 0.571589], [Corone, 0.57198024]] |\n", - "|old.That |[[old.But, 0.1474247], [old.It, 0.16579741], [old.And, 0.17444688], [old.Now, 0.18936878], [old.So, 0.19835782], [old.This, 0.21017289], [old.Then, 0.21463454], [old.The, 0.22473359], [old.As, 0.25048524], [old.One, 0.26592392]] |\n", - "|FAUX |[[FUX, 0.3556295], [Faux, 0.3758254], [FauxNews, 0.43965852], [LIBERAL, 0.44661516], [MSLSD, 0.45354593], [PMSNBC, 0.46690792], [DUmmies, 0.47646612], [FOXNEWS, 0.4791808], [MSDNC, 0.49204332], [Lamestream, 0.49492323]] |\n", - "|adultery |[[adultry, 0.16772509], [infidelity, 0.25239778], [Adultery, 0.2583952], [fornication, 0.267451], [unchastity, 0.27988046], [adulterer, 0.29999435], [adulterous, 0.30117166], [ADULTERY, 0.30291516], [bigamy, 0.31043363], [unfaithfulness, 0.3113429]] |\n", - "|media.What |[[media.That, 0.26153243], [media.This, 0.26800585], [media.So, 0.27167422], [media.It, 0.3081441], [media.Now, 0.31988323], [media.But, 0.32106495], [media.One, 0.33115608], [media.And, 0.33856434], [media.How, 0.34665102], [media.If, 0.35073084]] |\n", - "|cor |[[dï, 0.3961203], [dÃ, 0.4175074], [cor-, 0.43493056], [porations, 0.43639678], [pulmonale, 0.43843025], [respondence, 0.45021278], [ruption, 0.45934278], [cor., 0.4631384], [cile, 0.47538304], ['cor, 0.49732816]] |\n", - "|Inq. 
|[[Inq, 0.40202188], [obiit, 0.44469577], [Glouc, 0.47012448], [cxix, 0.47587764], [Occ., 0.47661322], [Eliz., 0.48363084], [Knt., 0.48907453], [xix., 0.49697137], [xvi., 0.49911338], [Ev., 0.49912524]] |\n", - "|Bizer |[[Lauster, 0.44218892], [Gradinger, 0.46392375], [Uschold, 0.4644981], [Lebherz, 0.46534657], [Stefik, 0.467287], [Knemeyer, 0.4674167], [Roether, 0.47718024], [Judisch, 0.47799456], [Brumme, 0.48313397], [Faler, 0.4832179]] |\n", - "|in.Answered |[[inquiryAnswered, 0.26990676], [inquiry.Answered, 0.27426267], [products.Answered, 0.2887978], [.Answered, 0.32311904], [youAsked, 0.39150316], [you.Answered, 0.39469278], [vehicle.Answered, 0.40893382], [cc-3Q, 0.41504997], [cc-2Q, 0.41548425], [you.Asked, 0.42146426]] |\n", - "|USD400 |[[USD300, 0.20369935], [USD40, 0.2153914], [USD4, 0.2218684], [USD15, 0.24091631], [USD7, 0.24179614], [USD1, 0.2456035], [USD6, 0.25030327], [USD500, 0.25821334], [USD150, 0.26162666], [USD5, 0.2638967]] |\n", - "|marriageManchester |[[serviceManchester, 0.30249298], [surgeManchester, 0.34075987], [possibleManchester, 0.34607416], [peopleManchester, 0.35320836], [venueManchester, 0.36918032], [StyleManchester, 0.3761459], [PriceManchester, 0.37993687], [homeManchester, 0.38265467], [2017LifeManchester, 0.39223439], [GrandeManchester, 0.3959509]] |\n", - "|muhaddith |[[Muhaddith, 0.10724366], [sahabi, 0.30339968], [hadith, 0.30517715], [traditionist, 0.30940914], [al-Dhahabi, 0.3129747], [Rahimahullah, 0.3180018], [mujtahid, 0.3223698], [al-Qummi, 0.32311952], [ḥadīth, 0.32375914], [Al-Dhahabi, 0.3273775]] |\n", - "|SystemYou |[[systemYou, 0.259013], [SystemIf, 0.28937984], [SystemWe, 0.30691057], [SystemIt, 0.31427503], [SystemThere, 0.33225638], [SystemI, 0.34262645], [SystemTo, 0.34420937], [SystemWhen, 0.3580963], [SystemThe, 0.37157464], [SystemHow, 0.37612015]] |\n", - "|58Scammers |[[8Shocking, 0.3516931], [6239How, 0.35719866], [Printers11Do, 0.41092116], [188DuperDheats, 0.4200685], [Thread2Shocking, 0.48599648], [ThreadGet, 0.49720424], [CheckEducationLanguages, 0.50729126], [398BIC, 0.51352537], [, 0.5160079], [WorksTeamValuesManifestoJoin, 0.51632655]] |\n", - "|brutalization |[[brutalisation, 0.25391227], [dehumanization, 0.26310456], [brutality, 0.2811919], [terrorization, 0.31872302], [brutalities, 0.33425742], [subjugation, 0.34178543], [de-humanization, 0.34935534], [dehumanisation, 0.3623349], [enslavement, 0.37435597], [barbarity, 0.3812967]] |\n", - "|Kompleks |[[kompleks, 0.24991596], [Wisma, 0.3846205], [Pusat, 0.39512396], [Bangunan, 0.40657693], [Megah, 0.41615975], [Anjung, 0.41982526], [UITM, 0.42523694], [Pejabat, 0.43077368], [Lapangan, 0.43323845], [Gadong, 0.43326092]] |\n", - "|RFLP |[[RFLPs, 0.15574348], [PCR-RFLP, 0.2023353], [RAPD, 0.25871873], [ARDRA, 0.2668311], [RAPD-PCR, 0.26967585], [AFLP, 0.2698329], [PFGE, 0.29038012], [RAPDs, 0.2907418], [rep-PCR, 0.2966656], [ribotyping, 0.30552095]] |\n", - "|Bolpur |[[Balurghat, 0.29711658], [Bongaon, 0.30289835], [Baharampur, 0.3033607], [Bankura, 0.30567527], [Krishnanagar, 0.3158574], [Arambagh, 0.32335418], [Jalpaiguri, 0.33144796], [Basirhat, 0.3317582], [Arambag, 0.33280414], [Joynagar, 0.33501327]] |\n", - "|Half-marathon |[[Half-Marathon, 0.17345399], [half-marathon, 0.18152589], [Marathon, 0.24437416], [marathon, 0.25628233], [4-miler, 0.26497716], [Marathon-, 0.26737952], [5-miler, 0.27653486], [10-miler, 0.28508943], [halfmarathon, 0.29552543], [Ultra-Marathon, 0.29792058]] |\n", - "|Pashana |[[gandha, 0.42467618], [Dhathri, 0.43248785], [Gandhak, 
0.4325605], [Somanatha, 0.43534237], [Gandha, 0.43653828], [Kedara, 0.43804038], [Jatamansi, 0.4435336], [Thailam, 0.44382733], [Auroshikha, 0.4487217], [Viswanathar, 0.45192194]] |\n", - "|Mountainview |[[Hillcrest, 0.3884545], [Hillview, 0.41177368], [Hillside, 0.42438316], [Hilltop, 0.43153], [Lakeview, 0.43236703], [Oakridge, 0.4456876], [Lakeside, 0.44626737], [Cedarview, 0.4476506], [Fairview, 0.44865865], [Pinebrook, 0.45652014]] |\n", - "|Strawberries- |[[strawberries-, 0.32558924], [Potatoes-, 0.3760872], [Cookies-, 0.37909764], [Fruit-, 0.38373864], [Chips-, 0.40149504], [Chocolate-, 0.4140039], [Cupcakes-, 0.4153443], [Seeds-, 0.4234345], [Snack-, 0.42579532], [Butter-, 0.43139273]] |\n", - "|kluttz |[[enniss, 0.46748388], [blackmer, 0.5220524], [jno, 0.52520955], [n.c, 0.5256657], [ijl, 0.53179747], [jji, 0.53273344], [.'., 0.5336398], [howerton, 0.5381024], [krider, 0.53874934], [i.l, 0.5402815]] |\n", - "|yBC |[[IBLF, 0.37565815], [e4e, 0.39584428], [Flytxt, 0.40649498], [BMGI, 0.40887225], [VMob, 0.41047966], [INXPO, 0.4135257], [TCUK, 0.41742367], [vcfo, 0.4181037], [Expion, 0.4204862], [Mubaloo, 0.42112082]] |\n", - "|InstituteCentral |[[InstituteUniversity, 0.39424777], [CollegeCentral, 0.40582484], [UniversityInstitute, 0.40619195], [InstituteNational, 0.41587758], [CollegeNorthwest, 0.41652536], [CollegeNorthwestern, 0.41781962], [CollegeNortheast, 0.41827053], [UniversityTrinity, 0.42332816], [CollegeHeritage, 0.42716473], [UniversityNorthwestern, 0.42863095]] |\n", - "|deal.Get |[[promotion.Get, 0.33704406], [merchant.Get, 0.34280938], [offer.Get, 0.36109442], [retailer.Get, 0.42141682], [discount.Get, 0.4313135], [discounts.Get, 0.45440573], [today.Get, 0.45729834], [prices.Get, 0.46747506], [site.Get, 0.4841087], [DealOnline, 0.48490703]] |\n", - "|oopss |[[ooppss, 0.17881781], [hhahaha, 0.21567678], [hahhaa, 0.2196294], [hhehe, 0.22064424], [hahaaaa, 0.22428185], [haahaha, 0.22476071], [hehehee, 0.22777593], [ooooops, 0.22884655], [sighh, 0.23201782], [hahaaa, 0.23202652]] |\n", - "|Enterworks |[[Enterp, 0.46053088], [Interworks, 0.47492576], [Teligent, 0.48701757], [TEOCO, 0.49184293], [TriVascular, 0.49956352], [Answerthink, 0.5032177], [BroadVision, 0.50519246], [Authentidate, 0.5058122], [deconsolidated, 0.5071093], [Asyst, 0.5076766]] |\n", - "|Indrapuri |[[Usmanpura, 0.31736088], [Ramchandrapur, 0.31939864], [Chandrapura, 0.3231414], [Brahmapuri, 0.3340426], [Shakarpur, 0.3377335], [Thakurganj, 0.33962828], [Durgapura, 0.3397761], [Gopalpura, 0.34235972], [Shankarpur, 0.34763026], [Alipura, 0.34859347]] |\n", - "|GuideCar |[[guideCar, 0.27506894], [GuideCanberra, 0.4653917], [GuideContact, 0.52747715], [GuideCleaning, 0.532438], [GuideContent, 0.5397519], [GuideClosed, 0.54824674], [GuideChris, 0.54955006], [Guide3, 0.55196565], [GuideComplimentary, 0.5528095], [GuideCruise, 0.5557381]] |\n", - "|35ViewThe |[[25ViewThe, 0.09465182], [45ViewThe, 0.11915678], [65ViewThe, 0.12135941], [85ViewThe, 0.12278879], [55ViewThe, 0.12575859], [90ViewThe, 0.13606101], [75ViewThe, 0.13861567], [40ViewThe, 0.14391416], [50ViewThe, 0.14580822], [80ViewThe, 0.14717925]] |\n", - "|her.Was |[[him.Was, 0.13783264], [anyone.Was, 0.18888414], [better.Was, 0.19741523], [me.Was, 0.19872057], [to.Was, 0.20547539], [doctor.Was, 0.2220295], [them.Was, 0.22396213], [be.Was, 0.22786236], [us.Was, 0.22904474], [care.Was, 0.23101228]] |\n", - "|Tche |[[Wuta, 0.40115637], [Awng, 0.41493613], [Ndombe, 0.42099226], [Boua, 0.43270826], [Bô, 0.4331659], [Hô, 0.43375325], [Kawng, 0.43583435], 
[Cuoi, 0.43640697], [Pya, 0.44193262], [Buyu, 0.4431575]] |\n", - "|Hefferman |[[McFaden, 0.44177163], [Tortorello, 0.4560455], [Barrone, 0.45779955], [Schraufnagel, 0.4695629], [Battson, 0.47042966], [McLaughlin, 0.4706053], [Greaver, 0.4834932], [McLauglin, 0.48358828], [Stauble, 0.48369557], [Feeney, 0.48422974]] |\n", - "|3-card |[[three-card, 0.15661728], [two-card, 0.25869733], [2-card, 0.26184517], [four-card, 0.27489406], [5-card, 0.27581465], [six-card, 0.2803859], [4-card, 0.286667], [6-card, 0.28720605], [7-card, 0.2961431], [five-card, 0.2975436]] |\n", - "|oh-so-amazing |[[not-so-amazing, 0.26233566], [oh-so-awesome, 0.3253194], [oh-so-lovely, 0.3915581], [oh-so-fabulous, 0.3927827], [ah-mazing, 0.4291818], [AH-mazing, 0.4304353], [oh-so-wonderful, 0.43389642], [oh-so-pretty, 0.43789083], [oh-so-adorable, 0.44143748], [uh-mazing, 0.44544053]] |\n", - "|SatisfactionLocation5.0Value4.0Maintenance4.0Office|[[SatisfactionLocation5.0Value4.0Maintenance5.0Office, 0.02729261], [SatisfactionLocation4.0Value4.0Maintenance4.0Office, 0.039210796], [SatisfactionLocation5.0Value5.0Maintenance5.0Office, 0.068318546], [5.0Value, 0.3597989], [PlayersD1D2ConferencesRankingsTransfersDraftsBlogAbout, 0.3624913], [NewsReviewsColumnsFeaturesShootoutsVideosBasher, 0.36544526], [AssemblyDeliveryValue, 0.37287682], [agency0Services0Reviews0Works0Team, 0.37342477], [ReviewsCamerasLensesLightingAudioPost, 0.37444127], [Value4.6, 0.3747821]]|\n", - "|Peculiars |[[peculiars, 0.29335696], [Paranormals, 0.52860606], [Chilbury, 0.5301968], [Unspeakables, 0.5332626], [Blackthorns, 0.535407], [Incorrigibles, 0.5515528], [Serrailler, 0.5537924], [Spyness, 0.55487597], [Demonglass, 0.5554242], [Coven, 0.55711365]] |\n", - "+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "only showing top 100 rows\n", - "\n" - ] - } - ], - "source": [ - "words_df_sample = words_df.sample(0.01)\n", - "\n", - "model.transform(words_df_sample) \\\n", - " .show(100, False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save the model" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "model.save(\"/tmp/model\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "PySpark", - "language": "python", - "name": "pyspark" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/README.md b/hnswlib-examples/hnswlib-examples-pyspark-luigi/README.md deleted file mode 100644 index fde8f75a..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/README.md +++ /dev/null @@ -1,19 +0,0 @@ 
-hnswlib-examples-pyspark-luigi -============================== - -Create a new virtual environment luigi-venv: - - python3 -m venv luigi-venv - -And activate the newly created virtual environment: - - . luigi-venv/bin/activate - -Install dependencies: - - pip install wheel luigi requests - -To execute the task you created, run the following command: - - python -m luigi --module flow Query --local-scheduler - diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py b/hnswlib-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py deleted file mode 100644 index 94848e57..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/bruteforce_index.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- - -import argparse - -from pyspark.ml import Pipeline -from pyspark.sql import SparkSession -from pyspark_hnsw.knn import * -from pyspark_hnsw.linalg import Normalizer - - -def main(spark): - parser = argparse.ArgumentParser(description='Construct brute force index') - parser.add_argument('--input', type=str) - parser.add_argument('--model', type=str) - parser.add_argument('--output', type=str) - parser.add_argument('--num_partitions', type=int) - - args = parser.parse_args() - - normalizer = Normalizer(inputCol='features', outputCol='normalized_features') - - bruteforce = BruteForceSimilarity(identifierCol='id', featuresCol='normalized_features', - distanceFunction='inner-product', numPartitions=args.num_partitions) - - pipeline = Pipeline(stages=[normalizer, bruteforce]) - - index_items = spark.read.parquet(args.input) - - model = pipeline.fit(index_items) - - model.write().overwrite().save(args.output) - - -if __name__ == '__main__': - main(SparkSession.builder.getOrCreate()) diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/convert.py b/hnswlib-examples/hnswlib-examples-pyspark-luigi/convert.py deleted file mode 100644 index f855f3c0..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/convert.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- - -import argparse - -from pyspark.ml.feature import VectorAssembler -from pyspark.sql import SparkSession -from pyspark_hnsw.conversion import VectorConverter - - -def main(spark): - parser = argparse.ArgumentParser(description='Convert input file to parquet') - parser.add_argument('--input', type=str) - parser.add_argument('--output', type=str) - args = parser.parse_args() - - words_df = spark.read \ - .option('inferSchema', 'true') \ - .option('quote', '\u0000') \ - .option('delimiter', ' ') \ - .csv(args.input) \ - .withColumnRenamed('_c0', 'id') - - vector_assembler = VectorAssembler(inputCols=words_df.columns[1:], outputCol='features_as_vector') - - converter = VectorConverter(inputCol='features_as_vector', outputCol='features', outputType='array') - - converter.transform(vector_assembler.transform(words_df)) \ - .select('id', 'features') \ - .write \ - .parquet(args.output) - - -if __name__ == "__main__": - main(SparkSession.builder.getOrCreate()) diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py b/hnswlib-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py deleted file mode 100644 index d90aaeb6..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/evaluate_performance.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- - -import argparse - -from pyspark.ml import PipelineModel -from pyspark.sql import SparkSession -from pyspark_hnsw.evaluation import KnnSimilarityEvaluator - - -def main(spark): - parser = 
argparse.ArgumentParser(description='Evaluate performance of the index') - parser.add_argument('--hnsw_model', type=str) - parser.add_argument('--bruteforce_model', type=str) - parser.add_argument('--input', type=str) - parser.add_argument('--output', type=str) - parser.add_argument('--k', type=int) - parser.add_argument('--ef', type=int) - parser.add_argument('--fraction', type=float) - parser.add_argument('--seed', type=int) - - args = parser.parse_args() - - sample_query_items = spark.read.parquet(args.input).sample(False, args.fraction, args.seed) - - hnsw_model = PipelineModel.read().load(args.hnsw_model) - - hnsw_stage = hnsw_model.stages[-1] - hnsw_stage.setEf(args.ef) - hnsw_stage.setK(args.k) - hnsw_stage.setPredictionCol('approximate') - hnsw_stage.setOutputFormat('full') - - bruteforce_model = PipelineModel.read().load(args.bruteforce_model) - - bruteforce_stage = bruteforce_model.stages[-1] - bruteforce_stage.setK(args.k) - bruteforce_stage.setPredictionCol('exact') - bruteforce_stage.setOutputFormat('full') - - sample_results = bruteforce_model.transform(hnsw_model.transform(sample_query_items)) - - evaluator = KnnSimilarityEvaluator(approximateNeighborsCol='approximate', exactNeighborsCol='exact') - - accuracy = evaluator.evaluate(sample_results) - - spark.createDataFrame([[accuracy]], ['accuracy']).repartition(1).write.mode('overwrite').csv(args.output) - - -if __name__ == '__main__': - main(SparkSession.builder.getOrCreate()) diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/flow.py b/hnswlib-examples/hnswlib-examples-pyspark-luigi/flow.py deleted file mode 100644 index cd4cbe7e..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/flow.py +++ /dev/null @@ -1,327 +0,0 @@ -# -*- coding: utf-8 -*- - -import urllib.request -import shutil - -import luigi -from luigi import FloatParameter, IntParameter, LocalTarget, Parameter -from luigi.contrib.spark import SparkSubmitTask -from luigi.format import Nop -from luigi.contrib.external_program import ExternalProgramTask -# from luigi.contrib.hdfs import HdfsFlagTarget -# from luigi.contrib.s3 import S3FlagTarget - - -class Download(luigi.Task): - """ - Download the input dataset. - """ - - url = Parameter(default='https://nlp.stanford.edu/data/glove.42B.300d.zip') - - def output(self): - return LocalTarget('/tmp/dataset.zip', format=Nop) - - def run(self): - # noinspection PyTypeChecker - with urllib.request.urlopen(self.url) as response: - with self.output().open('wb') as f: - shutil.copyfileobj(response, f) - - -class Unzip(ExternalProgramTask): - """ - Unzip the input dataset. - """ - - def requires(self): - return Download() - - def output(self): - return LocalTarget('/tmp/dataset', format=Nop) - - def program_args(self): - self.output().makedirs() - return ['unzip', - '-u', - '-q', - '-d', self.output().path, - self.input().path] - - -class Convert(SparkSubmitTask): - """ - Convert the input dataset to parquet. 
- """ - - # master = 'yarn' - master = 'local[*]' - - deploy_mode = 'client' - - driver_memory = '2g' - - # executor_memory = '4g' - - num_executors = IntParameter(default=2) - - name = 'Convert' - - app = 'convert.py' - - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] - - def requires(self): - return Unzip() - - def app_options(self): - return [ - "--input", self.input().path, - "--output", self.output().path - ] - - def output(self): - # return HdfsFlagTarget('/tmp/vectors_parquet') - # return S3FlagTarget('/tmp/vectors_parquet') - return LocalTarget('/tmp/vectors_parquet', format=Nop) - - -class HnswIndex(SparkSubmitTask): - """ - Construct the hnsw index and persists it to disk. - """ - - # master = 'yarn' - master = 'local[*]' - - deploy_mode = 'client' - - # driver_memory = '2g' - driver_memory = '24g' - - # executor_memory = '12g' - - num_executors = IntParameter(default=2) - - executor_cores = IntParameter(default=2) - - name = 'Hnsw index' - - app = 'hnsw_index.py' - - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] - - m = IntParameter(default=16) - - ef_construction = IntParameter(default=200) - - @property - def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s', - 'spark.hnswlib.settings.index.cache_folder': '/tmp'} - - def requires(self): - return Convert() - - def app_options(self): - return [ - '--input', self.input().path, - '--output', self.output().path, - '--m', self.m, - '--ef_construction', self.ef_construction, - '--num_partitions', str(self.num_executors) - ] - - def output(self): - # return HdfsFlagTarget('/tmp/hnsw_index') - # return S3FlagTarget('/tmp/hnsw_index') - return LocalTarget('/tmp/hnsw_index', format=Nop) - - -class Query(SparkSubmitTask): - """ - Query the constructed knn index. 
- """ - - # master = 'yarn' - master = 'local[*]' - - deploy_mode = 'client' - - # driver_memory = '2g' - driver_memory = '24g' - - # executor_memory = '10g' - - num_executors = IntParameter(default=4) - - executor_cores = IntParameter(default=2) - - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] - - name = 'Query index' - - app = 'query.py' - - k = IntParameter(default=10) - - ef = IntParameter(default=100) - - num_replicas = IntParameter(default=1) - - @property - def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s'} - - def requires(self): - return {'vectors': Convert(), - 'index': HnswIndex()} - - def app_options(self): - return [ - '--input', self.input()['vectors'].path, - '--model', self.input()['index'].path, - '--output', self.output().path, - '--ef', self.ef, - '--k', self.k, - '--num_replicas', self.num_replicas - ] - - def output(self): - # return HdfsFlagTarget('/tmp/query_results') - # return S3FlagTarget('/tmp/query_results') - return LocalTarget('/tmp/query_results') - - -class BruteForceIndex(SparkSubmitTask): - """ - Construct the brute force index and persists it to disk. - """ - - # master = 'yarn' - master = 'local[*]' - - deploy_mode = 'client' - - # driver_memory = '2g' - driver_memory = '24g' - - # executor_memory = '12g' - - num_executors = IntParameter(default=2) - - executor_cores = IntParameter(default=2) - - name = 'Brute force index' - - app = 'bruteforce_index.py' - - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] - - @property - def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s', - 'spark.hnswlib.settings.index.cache_folder': '/tmp'} - - def requires(self): - return Convert() - - def app_options(self): - return [ - '--input', self.input().path, - '--output', self.output().path, - '--num_partitions', str(self.num_executors) - ] - - def output(self): - # return HdfsFlagTarget('/tmp/brute_force_index') - # return S3FlagTarget('/tmp/brute_force_index') - return LocalTarget('/tmp/brute_force_index', format=Nop) - - -class Evaluate(SparkSubmitTask): - """ - Evaluate the accuracy of the approximate k-nearest neighbors model vs a bruteforce baseline. 
- """ - - # master = 'yarn' - master = 'local[*]' - - deploy_mode = 'client' - - # driver_memory = '2g' - driver_memory = '24g' - - # executor_memory = '12g' - - num_executors = IntParameter(default=2) - - executor_cores = IntParameter(default=2) - - k = IntParameter(default=10) - - ef = IntParameter(default=100) - - fraction = FloatParameter(default=0.0001) - - seed = IntParameter(default=123) - - name = 'Evaluate performance' - - app = 'evaluate_performance.py' - - packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0'] - - @property - def conf(self): - return {'spark.dynamicAllocation.enabled': 'false', - 'spark.speculation': 'false', - 'spark.serializer': 'org.apache.spark.serializer.KryoSerializer', - 'spark.kryo.registrator': 'com.github.jelmerk.spark.HnswLibKryoRegistrator', - 'spark.task.cpus': str(self.executor_cores), - 'spark.task.maxFailures': '1', - 'spark.scheduler.minRegisteredResourcesRatio': '1.0', - 'spark.scheduler.maxRegisteredResourcesWaitingTime': '3600s'} - - def requires(self): - return {'vectors': Convert(), - 'hnsw_index': HnswIndex(), - 'bruteforce_index': BruteForceIndex()} - - def app_options(self): - return [ - '--input', self.input()['vectors'].path, - '--output', self.output().path, - '--hnsw_model', self.input()['hnsw_index'].path, - '--bruteforce_model', self.input()['bruteforce_index'].path, - '--ef', self.ef, - '--k', self.k, - '--seed', self.seed, - '--fraction', self.fraction, - ] - - def output(self): - # return HdfsFlagTarget('/tmp/metrics') - # return S3FlagTarget('/tmp/metrics') - return LocalTarget('/tmp/metrics', format=Nop) diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py b/hnswlib-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py deleted file mode 100644 index 5a1df028..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/hnsw_index.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- - -import argparse - -from pyspark.ml import Pipeline -from pyspark.sql import SparkSession -from pyspark_hnsw.knn import * -from pyspark_hnsw.linalg import Normalizer - - -def main(spark): - parser = argparse.ArgumentParser(description='Construct hnsw index') - parser.add_argument('--input', type=str) - parser.add_argument('--output', type=str) - parser.add_argument('--m', type=int) - parser.add_argument('--ef_construction', type=int) - parser.add_argument('--num_partitions', type=int) - - args = parser.parse_args() - - normalizer = Normalizer(inputCol='features', outputCol='normalized_features') - - hnsw = HnswSimilarity(identifierCol='id', queryIdentifierCol='id', featuresCol='normalized_features', - distanceFunction='inner-product', m=args.m, efConstruction=args.ef_construction, - numPartitions=args.num_partitions, excludeSelf=True, outputFormat='minimal') - - pipeline = Pipeline(stages=[normalizer, hnsw]) - - index_items = spark.read.parquet(args.input) - - model = pipeline.fit(index_items) - - model.write().overwrite().save(args.output) - - -if __name__ == '__main__': - main(SparkSession.builder.getOrCreate()) diff --git a/hnswlib-examples/hnswlib-examples-pyspark-luigi/query.py b/hnswlib-examples/hnswlib-examples-pyspark-luigi/query.py deleted file mode 100644 index fbc5a859..00000000 --- a/hnswlib-examples/hnswlib-examples-pyspark-luigi/query.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- - -import argparse - -from pyspark.ml import PipelineModel -from pyspark.sql import SparkSession - - -def main(spark): - parser = argparse.ArgumentParser(description='Query index') - 
parser.add_argument('--input', type=str) - parser.add_argument('--model', type=str) - parser.add_argument('--output', type=str) - parser.add_argument('--k', type=int) - parser.add_argument('--ef', type=int) - parser.add_argument('--num_replicas', type=int) - - args = parser.parse_args() - - model = PipelineModel.read().load(args.model) - - hnsw_stage = model.stages[-1] - hnsw_stage.setEf(args.ef) - hnsw_stage.setK(args.k) - hnsw_stage.setNumReplicas(args.num_replicas) - - query_items = spark.read.parquet(args.input) - - results = model.transform(query_items) - - results.write.mode('overwrite').json(args.output) - - -if __name__ == '__main__': - main(SparkSession.builder.getOrCreate()) diff --git a/hnswlib-spark/README.md b/hnswlib-spark/README.md deleted file mode 100644 index f7281cde..00000000 --- a/hnswlib-spark/README.md +++ /dev/null @@ -1,151 +0,0 @@ -[![Maven Central](https://maven-badges.herokuapp.com/maven-central/com.github.jelmerk/hnswlib-spark_2.3_2.11/badge.svg)](https://maven-badges.herokuapp.com/maven-central/com.github.jelmerk/hnswlib-spark_2.3_2.11) [![Scaladoc](https://javadoc.io/badge2/com.github.jelmerk/hnswlib-spark_2.3_2.11/javadoc.svg)](https://javadoc.io/doc/com.github.jelmerk/hnswlib-spark_2.3_2.11) - - -hnswlib-spark -============= - -[Apache spark](https://spark.apache.org/) integration for hnswlib. - -About ------ - -The easiest way to use this library with spark is to simply collect your data on the driver node and index it there. -This does mean you'll have to allocate a lot of cores and memory to the driver. - -The alternative to this is to use this module to shard the index across multiple executors -and parallelize the indexing / querying. This may be faster if you have many executors at your disposal and is -appropriate when your dataset does not fit in the driver memory - -Distance functions optimized for use with sparse vectors will automatically be selected base on the input type - -Setup ------ - -Find the package appropriate for your spark setup - -| | Scala 2.11 | Scala 2.12 | -| ----------- |-------------------------------------------------|-------------------------------------------------| -| Spark 2.3.x | com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0 | | -| Spark 2.4.x | com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0 | com.github.jelmerk:hnswlib-spark_2.4_2.12:1.1.0 | -| Spark 3.0.x | | com.github.jelmerk:hnswlib-spark_3.0_2.12:1.1.0 | -| Spark 3.1.x | | com.github.jelmerk:hnswlib-spark_3.1_2.12:1.1.0 | -| Spark 3.2.x | | com.github.jelmerk:hnswlib-spark_3.2_2.12:1.1.0 | -| Spark 3.3.x | | com.github.jelmerk:hnswlib-spark_3.3_2.12:1.1.0 | - -Pass this as an argument to spark - - --packages 'com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0' - -Example usage -------------- - -Basic: - -```scala -import com.github.jelmerk.spark.knn.hnsw.HnswSimilarity - -val hnsw = new HnswSimilarity() - .setIdentifierCol("id") - .setQueryIdentifierCol("id") - .setFeaturesCol("features") - .setNumPartitions(2) - .setM(48) - .setEf(5) - .setEfConstruction(200) - .setK(200) - .setDistanceFunction("cosine") - .setExcludeSelf(true) - -val model = hnsw.fit(indexItems) - -model.transform(indexItems).write.parquet("/path/to/output") -``` - -Advanced: - -```scala -import org.apache.spark.ml.Pipeline - -import com.github.jelmerk.spark.knn.bruteforce.BruteForceSimilarity -import com.github.jelmerk.spark.knn.evaluation.KnnSimilarityEvaluator -import com.github.jelmerk.spark.knn.hnsw.HnswSimilarity -import com.github.jelmerk.spark.linalg.Normalizer -import 
com.github.jelmerk.spark.conversion.VectorConverter - -// often it is acceptable to use float instead of double precision. -// this uses less memory and will be faster - -val converter = new VectorConverter() - .setInputCol("featuresAsMlLibVector") - .setOutputCol("features") - -// The cosine distance is obtained with the inner product after normalizing all vectors to unit norm -// this is much faster than calculating the cosine distance directly - -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normalizedFeatures") - -val hnsw = new HnswSimilarity() - .setIdentifierCol("id") - .setQueryIdentifierCol("id") - .setFeaturesCol("normalizedFeatures") - .setNumPartitions(2) - .setK(200) - .setSimilarityThreshold(0.4) - .setDistanceFunction("inner-product") - .setPredictionCol("approximate") - .setExcludeSelf(true) - .setM(48) - .setEfConstruction(200) - -val bruteForce = new BruteForceSimilarity() - .setIdentifierCol(hnsw.getIdentifierCol) - .setQueryIdentifierCol(hnsw.getQueryIdentifierCol) - .setFeaturesCol(hnsw.getFeaturesCol) - .setNumPartitions(2) - .setK(hnsw.getK) - .setSimilarityThreshold(hnsw.getSimilarityThreshold) - .setDistanceFunction(hnsw.getDistanceFunction) - .setPredictionCol("exact") - .setExcludeSelf(hnsw.getExcludeSelf) - -val pipeline = new Pipeline() - .setStages(Array(converter, normalizer, hnsw, bruteForce)) - -val model = pipeline.fit(indexItems) - -// computing the exact similarity is expensive so only take a small sample -val queryItems = indexItems.sample(0.01) - -val output = model.transform(queryItems) - -val evaluator = new KnnSimilarityEvaluator() - .setApproximateNeighborsCol("approximate") - .setExactNeighborsCol("exact") - -val accuracy = evaluator.evaluate(output) - -println(s"Accuracy: $accuracy") - -// save the model -model.write.overwrite.save("/path/to/model") -``` - -Suggested configuration ------------------------ - -- set `executor.instances` to the same value as the numPartitions property of your Hnsw instance -- set `spark.executor.cores` to as high a value as feasible on your executors while not making your jobs impossible to schedule -- set `spark.task.cpus` to the same value as `spark.executor.cores` -- set `spark.scheduler.minRegisteredResourcesRatio` to `1.0` -- set `spark.scheduler.maxRegisteredResourcesWaitingTime` to `3600` -- set `spark.speculation` to `false` -- set `spark.dynamicAllocation.enabled` to `false` -- set `spark.task.maxFailures` to `1` -- set `spark.driver.memory`: to some arbitrary low value for instance `2g` will do because the model does not run on the driver -- set `spark.executor.memory`: to a value appropriate to the size of your data, typically this will be a large value -- set `spark.yarn.executor.memoryOverhead` to a value higher than `executorMemory * 0.10` if you get the "Container killed by YARN for exceeding memory limits" error -- set `spark.hnswlib.settings.index.cache_folder` to a folder with plenty of space that you can write to. Defaults to /tmp - -Note that as it stands increasing the number of partitions will speed up fitting the model but not querying the model. 
The only way to speed up querying is by increasing the number of replicas diff --git a/hnswlib-spark/run-pyspark-tests.sh b/hnswlib-spark/run-pyspark-tests.sh deleted file mode 100755 index 56579de6..00000000 --- a/hnswlib-spark/run-pyspark-tests.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env bash - -set -e - -SPARK_VERSION=$1 -PYTHON_VERSION=$2 - -# add python sources on the path -export PYTHONPATH=src/main/python - -# unset SPARK_HOME or it will use whatever is configured on the host system instead of the pip packages -unset SPARK_HOME - -# create a virtual environment - -eval "$PYTHON_VERSION -m venv "target/spark-$SPARK_VERSION-venv"" -source "target/spark-$SPARK_VERSION-venv/bin/activate" - -# install packages -pip install pytest==7.4.3 -pip install 'pyspark[ml]'=="$SPARK_VERSION" - -# run unit tests -pytest --junitxml=target/test-reports/TEST-python.xml \ No newline at end of file diff --git a/hnswlib-spark/src/main/python/__init__.py b/hnswlib-spark/src/main/python/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hnswlib-spark/src/main/python/com/__init__.py b/hnswlib-spark/src/main/python/com/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hnswlib-spark/src/main/python/com/github/__init__.py b/hnswlib-spark/src/main/python/com/github/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py deleted file mode 100644 index d7734861..00000000 --- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/conversion/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -import pyspark_hnsw.conversion - -sys.modules['com.github.jelmerk.spark.conversion'] = pyspark_hnsw.conversion diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py deleted file mode 100644 index 0393eccc..00000000 --- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/bruteforce/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -import pyspark_hnsw.knn - -sys.modules['com.github.jelmerk.spark.knn.bruteforce'] = pyspark_hnsw.knn diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py deleted file mode 100644 index c26a2188..00000000 --- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/evaluation/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -import pyspark_hnsw.evaluation - -sys.modules['com.github.jelmerk.spark.knn.evaluation'] = pyspark_hnsw.evaluation diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py deleted file mode 
100644 index b3e87b00..00000000 --- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/knn/hnsw/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -import pyspark_hnsw.knn - -sys.modules['com.github.jelmerk.spark.knn.hnsw'] = pyspark_hnsw.knn \ No newline at end of file diff --git a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py b/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py deleted file mode 100644 index 6d0fdbe4..00000000 --- a/hnswlib-spark/src/main/python/com/github/jelmerk/spark/linalg/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -import sys -import pyspark_hnsw.linalg - -sys.modules['com.github.jelmerk.spark.linalg'] = pyspark_hnsw.linalg \ No newline at end of file diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py b/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py deleted file mode 100644 index 1e2d3598..00000000 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/__init__.py +++ /dev/null @@ -1,189 +0,0 @@ -import subprocess -import threading -from pyspark.sql import SparkSession -from pyspark.conf import SparkConf -from pyspark.context import SparkContext -from pyspark.java_gateway import launch_gateway - - -def start(spark23=False, - spark24=False, - spark31=False, - memory="16G", - cache_folder="/tmp", - real_time_output=False, - output_level=1): - """Starts a PySpark instance with default parameters for Hnswlib. - - The default parameters would result in the equivalent of: - - .. code-block:: python - :param spark23: start Hnswlib on Apache Spark 2.3.x - :param spark24: start Hnswlib on Apache Spark 2.4.x - :param spark31: start Hnswlib on Apache Spark 3.1.x - :param memory: set driver memory for SparkSession - :param output_level: int, optional Output level for logs, by default 1 - :param real_time_output: - :substitutions: - - SparkSession.builder \\ - .appName("Hnswlib") \\ - .master("local[*]") \\ - .config("spark.driver.memory", "16G") \\ - .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \\ - .config("spark.kryo.registrator", "com.github.jelmerk.spark.HnswLibKryoRegistrator") \\ - .config("spark.jars.packages", "com.github.jelmerk:hnswlib-spark_3.0_2.12:|release|") \\ - .config("spark.hnswlib.settings.index.cache_folder", "/tmp") \\ - .getOrCreate() - - - Parameters - ---------- - - spark23 : bool, optional - Whether to use the Spark 2.3.x version of Hnswlib, by default False - spark24 : bool, optional - Whether to use the Spark 2.4.x version of Hnswlib, by default False - spark31 : bool, optional - Whether to use the Spark 3.1.x version of Hnswlib, by default False - memory : str, optional - How much memory to allocate for the Spark driver, by default "16G" - real_time_output : bool, optional - Whether to output in real time, by default False - output_level : int, optional - Output level for logs, by default 1 - - Returns - ------- - :class:`SparkSession` - The initiated Spark session. 
- - """ - current_version = "1.1.0" - - class HnswlibConfig: - - def __init__(self): - self.master = "local[*]" - self.app_name = "Hnswlib" - self.serializer = "org.apache.spark.serializer.KryoSerializer" - self.registrator = "com.github.jelmerk.spark.HnswLibKryoRegistrator" - # Hnswlib on Apache Spark 3.2.x - - # Hnswlib on Apache Spark 3.0.x/3.1.x - self.maven_spark = "com.github.jelmerk:hnswlib-spark_3.1_2.12:{}".format(current_version) - # Hnswlib on Apache Spark 2.4.x - self.maven_spark24 = "com.github.jelmerk:hnswlib-spark_2.4_2.12:{}".format(current_version) - # Hnswlib on Apache Spark 2.3.x - self.maven_spark23 = "com.github.jelmerk:hnswlib-spark_2.3_2.11:{}".format(current_version) - - def start_without_realtime_output(): - builder = SparkSession.builder \ - .appName(spark_nlp_config.app_name) \ - .master(spark_nlp_config.master) \ - .config("spark.driver.memory", memory) \ - .config("spark.serializer", spark_nlp_config.serializer) \ - .config("spark.kryo.registrator", spark_nlp_config.registrator) \ - .config("spark.hnswlib.settings.index.cache_folder", cache_folder) - - if spark23: - builder.config("spark.jars.packages", spark_nlp_config.maven_spark23) - elif spark24: - builder.config("spark.jars.packages", spark_nlp_config.maven_spark24) - else: - builder.config("spark.jars.packages", spark_nlp_config.maven_spark) - - return builder.getOrCreate() - - def start_with_realtime_output(): - - class SparkWithCustomGateway: - - def __init__(self): - spark_conf = SparkConf() - spark_conf.setAppName(spark_nlp_config.app_name) - spark_conf.setMaster(spark_nlp_config.master) - spark_conf.set("spark.driver.memory", memory) - spark_conf.set("spark.serializer", spark_nlp_config.serializer) - spark_conf.set("spark.kryo.registrator", spark_nlp_config.registrator) - spark_conf.set("spark.jars.packages", spark_nlp_config.maven_spark) - spark_conf.set("spark.hnswlib.settings.index.cache_folder", cache_folder) - - # Make the py4j JVM stdout and stderr available without buffering - popen_kwargs = { - 'stdout': subprocess.PIPE, - 'stderr': subprocess.PIPE, - 'bufsize': 0 - } - - # Launch the gateway with our custom settings - self.gateway = launch_gateway(conf=spark_conf, popen_kwargs=popen_kwargs) - self.process = self.gateway.proc - # Use the gateway we launched - spark_context = SparkContext(gateway=self.gateway) - self.spark_session = SparkSession(spark_context) - - self.out_thread = threading.Thread(target=self.output_reader) - self.error_thread = threading.Thread(target=self.error_reader) - self.std_background_listeners() - - def std_background_listeners(self): - self.out_thread.start() - self.error_thread.start() - - def output_reader(self): - for line in iter(self.process.stdout.readline, b''): - print('{0}'.format(line.decode('utf-8')), end='') - - def error_reader(self): - RED = '\033[91m' - RESET = '\033[0m' - for line in iter(self.process.stderr.readline, b''): - if output_level == 0: - print(RED + '{0}'.format(line.decode('utf-8')) + RESET, end='') - else: - # output just info - pass - - def shutdown(self): - self.spark_session.stop() - self.gateway.shutdown() - self.process.communicate() - - self.out_thread.join() - self.error_thread.join() - - return SparkWithCustomGateway() - - spark_nlp_config = HnswlibConfig() - - if real_time_output: - if spark23 or spark24: - spark_session = start_without_realtime_output() - return spark_session - else: - # Available from Spark 3.0.x - class SparkRealTimeOutput: - - def __init__(self): - self.__spark_with_custom_gateway = 
start_with_realtime_output() - self.spark_session = self.__spark_with_custom_gateway.spark_session - - def shutdown(self): - self.__spark_with_custom_gateway.shutdown() - - return SparkRealTimeOutput() - else: - spark_session = start_without_realtime_output() - return spark_session - - -def version(): - """Returns the current Hnswlib version. - - Returns - ------- - str - The current Hnswlib version. - """ - return '1.1.0' diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py b/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py deleted file mode 100644 index 1dcfd366..00000000 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/conversion.py +++ /dev/null @@ -1,35 +0,0 @@ -from pyspark.ml.param.shared import * -from pyspark.ml.wrapper import JavaTransformer -from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.mllib.common import inherit_doc -from pyspark import keyword_only - -__all__ = ['VectorConverter'] - -@inherit_doc -class VectorConverter(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): - """ - Converts the input vector to a vector of another type. - """ - - outputType = Param(Params._dummy(), "outputType", "type of vector to produce. one of array, array, vector", - typeConverter=TypeConverters.toString) - - @keyword_only - def __init__(self, inputCol="input", outputCol="output", outputType="array"): - """ - __init__(self, inputCol="input", outputCol="output", outputType="array") - """ - super(VectorConverter, self).__init__() - self._java_obj = self._new_java_obj("com.github.jelmerk.spark.conversion.VectorConverter", self.uid) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, inputCol="input", outputCol="output", outputType="array"): - """ - setParams(self, inputCol="input", outputCol="output", outputType="array") - Sets params for this VectorConverter. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/evaluation.py b/hnswlib-spark/src/main/python/pyspark_hnsw/evaluation.py deleted file mode 100644 index 2f556693..00000000 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/evaluation.py +++ /dev/null @@ -1,54 +0,0 @@ -from pyspark.ml.evaluation import JavaEvaluator -from pyspark.ml.param.shared import * -from pyspark.mllib.common import inherit_doc -from pyspark import keyword_only -from pyspark.ml.util import JavaMLReadable, JavaMLWritable - -__all__ = ['KnnSimilarityEvaluator'] - -@inherit_doc -class KnnSimilarityEvaluator(JavaEvaluator, JavaMLReadable, JavaMLWritable): - """ - Evaluate the performance of a knn model. - """ - @keyword_only - def __init__(self, approximateNeighborsCol="approximateNeighbors", exactNeighborsCol="exactNeighbors"): - super(JavaEvaluator, self).__init__() - self._java_obj = self._new_java_obj("com.github.jelmerk.spark.knn.evaluation.KnnSimilarityEvaluator", self.uid) - - self.approximateNeighborsCol = Param(self, "approximateNeighborsCol", "the column name for the row identifier") - self.exactNeighborsCol = Param(self, "exactNeighborsCol", "the column name for the vector") - - self._setDefault(approximateNeighborsCol="approximateNeighbors", exactNeighborsCol="exactNeighbors") - - kwargs = self._input_kwargs - self.setParams(**kwargs) - - def getApproximateNeighborsCol(self): - """ - Gets the value of approximateNeighborsCol or its default value. 
- """ - return self.getOrDefault(self.approximateNeighborsCol) - - def setApproximateNeighborsCol(self, value): - """ - Sets the value of :py:attr:`approximateNeighborsCol`. - """ - return self._set(approximateNeighborsCol=value) - - def getExactNeighborsCol(self): - """ - Gets the value of exactNeighborsCol or its default value. - """ - return self.getOrDefault(self.exactNeighborsCol) - - def setExactNeighborsCol(self, value): - """ - Sets the value of :py:attr:`exactNeighborsCol`. - """ - return self._set(exactNeighborsCol=value) - - @keyword_only - def setParams(self, approximateNeighborsCol="approximateNeighbors", exactNeighborsCol="exactNeighbors"): - kwargs = self._input_kwargs - return self._set(**kwargs) diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py b/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py deleted file mode 100644 index 6380abcc..00000000 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py +++ /dev/null @@ -1,570 +0,0 @@ -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams -from pyspark.ml.param.shared import * -from pyspark.mllib.common import inherit_doc -from pyspark import keyword_only -from pyspark.ml.util import JavaMLReadable, JavaMLWritable, MLReader, _jvm - -__all__ = ['HnswSimilarity', 'HnswSimilarityModel', 'BruteForceSimilarity', 'BruteForceSimilarityModel', 'HnswLibMLReader'] - -class HnswLibMLReader(MLReader): - - """ - Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types - """ - - def __init__(self, clazz, java_class): - self._clazz = clazz - self._jread = self._load_java_obj(java_class).read() - - def load(self, path): - """Load the ML instance from the input path.""" - java_obj = self._jread.load(path) - return self._clazz._from_java(java_obj) - - @classmethod - def _load_java_obj(cls, java_class): - """Load the peer Java object of the ML instance.""" - java_obj = _jvm() - for name in java_class.split("."): - java_obj = getattr(java_obj, name) - return java_obj - -@inherit_doc -class _KnnModelParams(HasFeaturesCol, HasPredictionCol): - """ - Params for knn models. - """ - - queryIdentifierCol = Param(Params._dummy(), "queryIdentifierCol", "the column name for the query identifier", - typeConverter=TypeConverters.toString) - - queryPartitionsCol = Param(Params._dummy(), "queryPartitionsCol", "the column name for the query partitions", - typeConverter=TypeConverters.toString) - - parallelism = Param(Params._dummy(), "parallelism", "number of threads to use", typeConverter=TypeConverters.toInt) - - k = Param(Params._dummy(), "k", "number of neighbors to find", typeConverter=TypeConverters.toInt) - - numReplicas = Param(Params._dummy(), "numReplicas", "number of index replicas to create when querying", typeConverter=TypeConverters.toInt) - - excludeSelf = Param(Params._dummy(), "excludeSelf", "whether to include the row identifier as a candidate neighbor", - typeConverter=TypeConverters.toBoolean) - - similarityThreshold = Param(Params._dummy(), "similarityThreshold", - "do not return neighbors further away than this distance", - typeConverter=TypeConverters.toFloat) - - outputFormat = Param(Params._dummy(), "outputFormat", "output format, one of full, minimal", - typeConverter=TypeConverters.toString) - - def getQueryIdentifierCol(self): - """ - Gets the value of queryIdentifierCol or its default value. - """ - return self.getOrDefault(self.queryIdentifierCol) - - def getQueryPartitionsCol(self): - """ - Gets the value of queryPartitionsCol or its default value. 
- """ - return self.getOrDefault(self.queryPartitionsCol) - - def getParallelism(self): - """ - Gets the value of parallelism or its default value. - """ - return self.getOrDefault(self.parallelism) - - def getK(self): - """ - Gets the value of k or its default value. - """ - return self.getOrDefault(self.k) - - def getExcludeSelf(self): - """ - Gets the value of excludeSelf or its default value. - """ - return self.getOrDefault(self.excludeSelf) - - def getSimilarityThreshold(self): - """ - Gets the value of similarityThreshold or its default value. - """ - return self.getOrDefault(self.similarityThreshold) - - def getOutputFormat(self): - """ - Gets the value of outputFormat or its default value. - """ - return self.getOrDefault(self.outputFormat) - - def getNumReplicas(self): - """ - Gets the value of numReplicas or its default value. - """ - return self.getOrDefault(self.numReplicas) - - -@inherit_doc -class _KnnParams(_KnnModelParams): - """ - Params for knn algorithms. - """ - - identifierCol = Param(Params._dummy(), "identifierCol", "the column name for the row identifier", - typeConverter=TypeConverters.toString) - - partitionCol = Param(Params._dummy(), "partitionCol", "the column name for the partition", - typeConverter=TypeConverters.toString) - - initialModelPath = Param(Params._dummy(), "initialModelPath", "path to the initial model", - typeConverter=TypeConverters.toString) - - numPartitions = Param(Params._dummy(), "numPartitions", "number of partitions", typeConverter=TypeConverters.toInt) - - distanceFunction = Param(Params._dummy(), "distanceFunction", - "distance function, one of bray-curtis, canberra, cosine, correlation, " + - "euclidean, inner-product, manhattan or the fully qualified classname " + - "of a distance function", typeConverter=TypeConverters.toString) - - def getIdentifierCol(self): - """ - Gets the value of identifierCol or its default value. - """ - return self.getOrDefault(self.identifierCol) - - def getPartitionCol(self): - """ - Gets the value of partitionCol or its default value. - """ - return self.getOrDefault(self.partitionCol) - - def getInitialModelPath(self): - """ - Gets the value of initialModelPath or its default value. - """ - return self.getOrDefault(self.initialModelPath) - - def getNumPartitions(self): - """ - Gets the value of numPartitions or its default value. - """ - return self.getOrDefault(self.numPartitions) - - def getDistanceFunction(self): - """ - Gets the value of distanceFunction or its default value. - """ - return self.getOrDefault(self.distanceFunction) - - -@inherit_doc -class _HnswModelParams(_KnnModelParams): - """ - Params for :py:class:`Hnsw` and :py:class:`HnswModel`. - """ - - ef = Param(Params._dummy(), "ef", "size of the dynamic list for the nearest neighbors (used during the search)", - typeConverter=TypeConverters.toInt) - - def getEf(self): - """ - Gets the value of ef or its default value. - """ - return self.getOrDefault(self.ef) - - -@inherit_doc -class _HnswParams(_HnswModelParams, _KnnParams): - """ - Params for :py:class:`Hnsw`. - """ - - m = Param(Params._dummy(), "m", "number of bi-directional links created for every new element during construction", - typeConverter=TypeConverters.toInt) - - efConstruction = Param(Params._dummy(), "efConstruction", - "has the same meaning as ef, but controls the index time / index precision", - typeConverter=TypeConverters.toInt) - - def getM(self): - """ - Gets the value of m or its default value. 
- """ - return self.getOrDefault(self.m) - - def getEfConstruction(self): - """ - Gets the value of efConstruction or its default value. - """ - return self.getOrDefault(self.efConstruction) - - -@inherit_doc -class BruteForceSimilarity(JavaEstimator, _KnnParams, JavaMLReadable, JavaMLWritable): - """ - Exact nearest neighbour search. - """ - - @keyword_only - def __init__(self, identifierCol="id", partitionCol=None, queryIdentifierCol=None, queryPartitionsCol=None, - parallelism= None, featuresCol="features", predictionCol="prediction", numPartitions=1, numReplicas=0, - k=5, distanceFunction="cosine", excludeSelf=False, similarityThreshold=-1.0, outputFormat="full", - initialModelPath=None): - super(BruteForceSimilarity, self).__init__() - self._java_obj = self._new_java_obj("com.github.jelmerk.spark.knn.bruteforce.BruteForceSimilarity", self.uid) - - self._setDefault(identifierCol="id", numPartitions=1, numReplicas=0, k=5, distanceFunction="cosine", - excludeSelf=False, similarityThreshold=-1.0, outputFormat="full") - - kwargs = self._input_kwargs - self.setParams(**kwargs) - - def setIdentifierCol(self, value): - """ - Sets the value of :py:attr:`identifierCol`. - """ - return self._set(identifierCol=value) - - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - - def setPartitionCol(self, value): - """ - Sets the value of :py:attr:`partitionCol`. - """ - return self._set(partitionCol=value) - - def setQueryPartitionsCol(self, value): - """ - Sets the value of :py:attr:`queryPartitionsCol`. - """ - return self._set(queryPartitionsCol=value) - - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - - def setNumPartitions(self, value): - """ - Sets the value of :py:attr:`numPartitions`. - """ - return self._set(numPartitions=value) - - def setNumReplicas(self, value): - """ - Sets the value of :py:attr:`numReplicas`. - """ - return self._set(numReplicas=value) - - def setK(self, value): - """ - Sets the value of :py:attr:`k`. - """ - return self._set(k=value) - - def setDistanceFunction(self, value): - """ - Sets the value of :py:attr:`distanceFunction`. - """ - return self._set(distanceFunction=value) - - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. - """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - - def setInitialModelPath(self, value): - """ - Sets the value of :py:attr:`initialModelPath`. - """ - return self._set(initialModelPath=value) - - @keyword_only - def setParams(self, identifierCol="id", queryIdentifierCol=None, queryPartitionsCol=None, parallelism=None, - featuresCol="features", predictionCol="prediction",numPartitions=1, numReplicas=0, k=5, - distanceFunction="cosine", excludeSelf=False, similarityThreshold=-1.0, outputFormat="full", - initialModelPath=None): - kwargs = self._input_kwargs - return self._set(**kwargs) - - def _create_model(self, java_model): - return BruteForceSimilarityModel(java_model) - - -class BruteForceSimilarityModel(JavaModel, _KnnModelParams, JavaMLReadable, JavaMLWritable): - """ - Model fitted by BruteForce. 
- """ - - _classpath_model = 'com.github.jelmerk.spark.knn.bruteforce.BruteForceSimilarityModel' - - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - - def setQueryPartitionsCol(self, value): - """ - Sets the value of :py:attr:`queryPartitionsCol`. - """ - return self._set(queryPartitionsCol=value) - - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - - def setK(self, value): - """ - Sets the value of :py:attr:`k`. - """ - return self._set(k=value) - - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. - """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - - def setNumReplicas(self, value): - """ - Sets the value of :py:attr:`numReplicas`. - """ - return self._set(numReplicas=value) - - @classmethod - def read(cls): - return HnswLibMLReader(cls, cls._classpath_model) - - -@inherit_doc -class HnswSimilarity(JavaEstimator, _HnswParams, JavaMLReadable, JavaMLWritable): - """ - Approximate nearest neighbour search. - """ - - @keyword_only - def __init__(self, identifierCol="id", queryIdentifierCol=None, queryPartitionsCol=None, parallelism=None, - featuresCol="features", predictionCol="prediction", m=16, ef=10, efConstruction=200, numPartitions=1, - numReplicas=0, k=5, distanceFunction="cosine", excludeSelf=False, similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None): - super(HnswSimilarity, self).__init__() - self._java_obj = self._new_java_obj("com.github.jelmerk.spark.knn.hnsw.HnswSimilarity", self.uid) - - self._setDefault(identifierCol="id", m=16, ef=10, efConstruction=200, numPartitions=1, numReplicas=0, k=5, - distanceFunction="cosine", excludeSelf=False, similarityThreshold=-1.0, outputFormat="full", - initialModelPath=None) - - kwargs = self._input_kwargs - self.setParams(**kwargs) - - def setIdentifierCol(self, value): - """ - Sets the value of :py:attr:`identifierCol`. - """ - return self._set(identifierCol=value) - - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - - def setPartitionCol(self, value): - """ - Sets the value of :py:attr:`partitionCol`. - """ - return self._set(partitionCol=value) - - def setQueryPartitionsCol(self, value): - """ - Sets the value of :py:attr:`queryPartitionsCol`. - """ - return self._set(queryPartitionsCol=value) - - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - - def setNumPartitions(self, value): - """ - Sets the value of :py:attr:`numPartitions`. - """ - return self._set(numPartitions=value) - - def setNumReplicas(self, value): - """ - Sets the value of :py:attr:`numReplicas`. - """ - return self._set(numReplicas=value) - - def setK(self, value): - """ - Sets the value of :py:attr:`k`. - """ - return self._set(k=value) - - def setDistanceFunction(self, value): - """ - Sets the value of :py:attr:`distanceFunction`. - """ - return self._set(distanceFunction=value) - - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. 
- """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - - def setM(self, value): - """ - Sets the value of :py:attr:`m`. - """ - return self._set(m=value) - - def setEf(self, value): - """ - Sets the value of :py:attr:`ef`. - """ - return self._set(ef=value) - - def setEfConstruction(self, value): - """ - Sets the value of :py:attr:`efConstruction`. - """ - return self._set(efConstruction=value) - - def setInitialModelPath(self, value): - """ - Sets the value of :py:attr:`initialModelPath`. - """ - return self._set(initialModelPath=value) - - @keyword_only - def setParams(self, identifierCol="id", queryIdentifierCol=None, queryPartitionsCol=None, parallelism=None, - featuresCol="features", predictionCol="prediction", m=16, ef=10, efConstruction=200, numPartitions=1, - numReplicas=0, k=5, distanceFunction="cosine", excludeSelf=False, similarityThreshold=-1.0, - outputFormat="full", initialModelPath=None): - kwargs = self._input_kwargs - return self._set(**kwargs) - - def _create_model(self, java_model): - return HnswSimilarityModel(java_model) - - -class HnswSimilarityModel(JavaModel, _HnswModelParams, JavaMLReadable, JavaMLWritable): - """ - Model fitted by Hnsw. - """ - - _classpath_model = 'com.github.jelmerk.spark.knn.hnsw.HnswSimilarityModel' - - def setQueryIdentifierCol(self, value): - """ - Sets the value of :py:attr:`queryIdentifierCol`. - """ - return self._set(queryIdentifierCol=value) - - def setQueryPartitionsCol(self, value): - """ - Sets the value of :py:attr:`queryPartitionsCol`. - """ - return self._set(queryPartitionsCol=value) - - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - - def setK(self, value): - """ - Sets the value of :py:attr:`k`. - """ - return self._set(k=value) - - def setEf(self, value): - """ - Sets the value of :py:attr:`ef`. - """ - return self._set(ef=value) - - def setExcludeSelf(self, value): - """ - Sets the value of :py:attr:`excludeSelf`. - """ - return self._set(excludeSelf=value) - - def setSimilarityThreshold(self, value): - """ - Sets the value of :py:attr:`similarityThreshold`. - """ - return self._set(similarityThreshold=value) - - def setOutputFormat(self, value): - """ - Sets the value of :py:attr:`outputFormat`. - """ - return self._set(outputFormat=value) - - def setNumReplicas(self, value): - """ - Sets the value of :py:attr:`numReplicas`. 
- """ - return self._set(numReplicas=value) - - @classmethod - def read(cls): - return HnswLibMLReader(cls, cls._classpath_model) - - -HnswSimilarityModelImpl = HnswSimilarityModel -BruteForceSimilarityModelImpl = BruteForceSimilarityModel \ No newline at end of file diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/linalg.py b/hnswlib-spark/src/main/python/pyspark_hnsw/linalg.py deleted file mode 100644 index 2c7b9501..00000000 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/linalg.py +++ /dev/null @@ -1,32 +0,0 @@ -from pyspark.ml.param.shared import HasInputCol, HasOutputCol -from pyspark.ml.wrapper import JavaTransformer -from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.mllib.common import inherit_doc -from pyspark import keyword_only - -__all__ = ['Normalizer'] - -@inherit_doc -class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): - """ - Normalizes vectors to unit norm - """ - - @keyword_only - def __init__(self, inputCol="input", outputCol="output"): - """ - __init__(self, inputCol="input", outputCol="output") - """ - super(Normalizer, self).__init__() - self._java_obj = self._new_java_obj("com.github.jelmerk.spark.linalg.Normalizer", self.uid) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, inputCol="input", outputCol="output"): - """ - setParams(self, inputCol="input", outputCol="output") - Sets params for this Normalizer. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/HnswLibKryoRegistrator.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/HnswLibKryoRegistrator.scala deleted file mode 100644 index 70c7e2a4..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/HnswLibKryoRegistrator.scala +++ /dev/null @@ -1,18 +0,0 @@ -package com.github.jelmerk.spark - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.serializers.JavaSerializer -import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex -import com.github.jelmerk.spark.util.SerializableConfiguration -import org.apache.spark.serializer.KryoRegistrator - -/** - * Implementation of KryoRegistrator that registers hnswlib classes with spark. 
- * Can be registered by setting spark.kryo.registrator to com.github.jelmerk.spark.HnswLibKryoRegistrator - */ -class HnswLibKryoRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo): Unit = { - kryo.register(classOf[HnswIndex[_, _, _, _]], new JavaSerializer) - kryo.register(classOf[SerializableConfiguration], new JavaSerializer) - } -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/conversion/VectorConverter.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/conversion/VectorConverter.scala deleted file mode 100644 index 28756d58..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/conversion/VectorConverter.scala +++ /dev/null @@ -1,119 +0,0 @@ -package com.github.jelmerk.spark.conversion - -import com.github.jelmerk.spark.linalg.Normalizer -import org.apache.spark.internal.Logging -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.linalg.SQLDataTypes._ -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType, StructType} - -/** - * Companion class for VectorConverter. - */ -object VectorConverter extends DefaultParamsReadable[Normalizer] { - override def load(path: String): Normalizer = super.load(path) -} - -private[conversion] trait VectorConverterParams extends HasInputCol with HasOutputCol { - - /** - * Param for the type of vector to produce. one of array<float>, array<double>, vector - * Default: "array<float>" - * - * @group param - */ - final val outputType: Param[String] = new Param[String](this, "outputType", "type of vector to produce") - - /** @group getParam */ - final def getOutputType: String = $(outputType) - - setDefault(outputType -> "array<float>") -} - -/** - * Converts the input vector to a vector of another type.
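The outputType param above accepts array<float>, array<double> or vector. A rough sketch of driving the conversion from Python, assuming the matching pyspark_hnsw.conversion wrapper (not shown in this excerpt) mirrors the Scala parameters:

from pyspark_hnsw.conversion import VectorConverter   # assumed wrapper module and class name

converter = VectorConverter(inputCol='features_as_doubles', outputCol='features',
                            outputType='array<float>')
converted_items = converter.transform(items)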
- * - * @param uid identifier - */ -class VectorConverter(override val uid: String) - extends Transformer with VectorConverterParams with Logging with DefaultParamsWritable { - - def this() = this(Identifiable.randomUID("conv")) - - /** @group setParam */ - def setInputCol(value: String): this.type = set(inputCol, value) - - /** @group setParam */ - def setOutputCol(value: String): this.type = set(outputCol, value) - - /** @group setParam */ - def setOutputType(value: String): this.type = set(outputType, value) - - override def transform(dataset: Dataset[_]): DataFrame = { - - dataset.withColumn(getOutputCol, (dataset.schema(getInputCol).dataType, getOutputType) match { - case (ArrayType(FloatType, _), "array<double>") => floatArrayToDoubleArray(col(getInputCol)) - case (ArrayType(FloatType, _), "vector") => floatArrayToVector(col(getInputCol)) - - case (ArrayType(DoubleType, _), "array<float>") => doubleArrayToFloatArray(col(getInputCol)) - case (ArrayType(DoubleType, _), "vector") => doubleArrayToVector(col(getInputCol)) - - case (VectorType, "array<float>") => vectorToFloatArray(col(getInputCol)) - case (VectorType, "array<double>") => vectorToDoubleArray(col(getInputCol)) - - case _ => throw new IllegalArgumentException("Cannot convert vector") - }) - } - - override def copy(extra: ParamMap): Transformer = defaultCopy(extra) - - override def transformSchema(schema: StructType): StructType = { - if (schema.fieldNames.contains(getOutputCol)) { - throw new IllegalArgumentException(s"Output column $getOutputCol already exists.") - } - - if (!schema.fieldNames.contains(getInputCol)) { - throw new IllegalArgumentException(s"Input column $getInputCol does not exist.") - } - - val inputColumnSchema = schema(getInputCol) - - val inputColHasValidDataType = inputColumnSchema.dataType match { - case VectorType => true - case ArrayType(DoubleType, _) => true - case _ => false - } - - if (!inputColHasValidDataType) { - throw new IllegalArgumentException(s"Input column $getInputCol must be a double array or vector.") - } - - val outputType: DataType = getOutputType match { - case "array<double>" => ArrayType(DoubleType) - case "array<float>" => ArrayType(FloatType) - case "vector" => VectorType - } - - schema - .add(getOutputCol, outputType, inputColumnSchema.nullable) - } - - private val vectorToFloatArray: UserDefinedFunction = udf { vector: Vector => vector.toArray.map(_.toFloat) } - - private val doubleArrayToFloatArray: UserDefinedFunction = udf { vector: Seq[Double] => vector.map(_.toFloat) } - - private val floatArrayToDoubleArray: UserDefinedFunction = udf { vector: Seq[Float] => vector.toArray.map(_.toDouble) } - - private val vectorToDoubleArray: UserDefinedFunction = udf { vector: Vector => vector.toArray } - - private val floatArrayToVector: UserDefinedFunction = udf { vector: Seq[Float] => Vectors.dense(vector.map(_.toDouble).toArray) } - - private val doubleArrayToVector: UserDefinedFunction = udf { vector: Seq[Double] => Vectors.dense(vector.toArray) } - -} \ No newline at end of file diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala deleted file mode 100644 index 823e7c23..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala +++ /dev/null @@ -1,1015 +0,0 @@ -package com.github.jelmerk.spark.knn - -import java.io.InputStream -import java.net.InetAddress -import java.util.concurrent.{CountDownLatch, ExecutionException, FutureTask, LinkedBlockingQueue, ThreadLocalRandom,
ThreadPoolExecutor, TimeUnit} -import com.github.jelmerk.knn.ObjectSerializer - -import scala.language.{higherKinds, implicitConversions} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe._ -import scala.util.Try -import scala.Seq -import org.apache.hadoop.fs.{FileUtil, Path} -import org.apache.spark.{Partitioner, TaskContext} -import org.apache.spark.internal.Logging -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.Vector -import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{MLReader, MLWriter} -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ -import org.json4s.jackson.JsonMethods._ -import org.json4s._ -import com.github.jelmerk.knn.scalalike._ -import com.github.jelmerk.knn.util.NamedThreadFactory -import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions -import com.github.jelmerk.spark.util.SerializableConfiguration -import org.apache.spark.ml.linalg.SQLDataTypes.VectorType -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder - -import scala.annotation.tailrec -import scala.util.control.NonFatal - - -private[knn] case class IntDoubleArrayIndexItem(id: Int, vector: Array[Double]) extends Item[Int, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class LongDoubleArrayIndexItem(id: Long, vector: Array[Double]) extends Item[Long, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class StringDoubleArrayIndexItem(id: String, vector: Array[Double]) extends Item[String, Array[Double]] { - override def dimensions: Int = vector.length -} - - -private[knn] case class IntFloatArrayIndexItem(id: Int, vector: Array[Float]) extends Item[Int, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class LongFloatArrayIndexItem(id: Long, vector: Array[Float]) extends Item[Long, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) extends Item[String, Array[Float]] { - override def dimensions: Int = vector.length -} - - -private[knn] case class IntVectorIndexItem(id: Int, vector: Vector) extends Item[Int, Vector] { - override def dimensions: Int = vector.size -} - -private[knn] case class LongVectorIndexItem(id: Long, vector: Vector) extends Item[Long, Vector] { - override def dimensions: Int = vector.size -} - -private[knn] case class StringVectorIndexItem(id: String, vector: Vector) extends Item[String, Vector] { - override def dimensions: Int = vector.size -} - -/** - * Neighbor of an item. - * - * @param neighbor identifies the neighbor - * @param distance distance to the item - * - * @tparam TId type of the index item identifier - * @tparam TDistance type of distance - */ -private[knn] case class Neighbor[TId, TDistance] (neighbor: TId, distance: TDistance) - -/** - * Common params for KnnAlgorithm and KnnModel. - */ -private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPredictionCol { - - /** - * Param for the column name for the query identifier. 
- * - * @group param - */ - final val queryIdentifierCol = new Param[String](this, "queryIdentifierCol", "column name for the query identifier") - - /** @group getParam */ - final def getQueryIdentifierCol: String = $(queryIdentifierCol) - - /** - * Param for the column name for the query partitions. - * - * @group param - */ - final val queryPartitionsCol = new Param[String](this, "queryPartitionsCol", "column name for the query partitions") - - /** @group getParam */ - final def getQueryPartitionsCol: String = $(queryPartitionsCol) - - /** - * Param for number of neighbors to find (> 0). - * Default: 5 - * - * @group param - */ - final val k = new IntParam(this, "k", "number of neighbors to find", ParamValidators.gt(0)) - - /** @group getParam */ - final def getK: Int = $(k) - - /** - * Param that indicates whether to not return the a candidate when it's identifier equals the query identifier - * Default: false - * - * @group param - */ - final val excludeSelf = new BooleanParam(this, "excludeSelf", "whether to include the row identifier as a candidate neighbor") - - /** @group getParam */ - final def getExcludeSelf: Boolean = $(excludeSelf) - - /** - * Param for the threshold value for inclusion. -1 indicates no threshold - * Default: -1 - * - * @group param - */ - final val similarityThreshold = new DoubleParam(this, "similarityThreshold", "do not return neighbors further away than this distance") - - /** @group getParam */ - final def getSimilarityThreshold: Double = $(similarityThreshold) - - /** - * Param that specifies the number of index replicas to create when querying the index. More replicas means you can - * execute more queries in parallel at the expense of increased resource usage. - * Default: 0 - * - * @group param - */ - final val numReplicas = new IntParam(this, "numReplicas", "number of index replicas to create when querying") - - /** @group getParam */ - final def getNumReplicas: Int = $(numReplicas) - - /** - * Param that specifies the number of threads to use. - * Default: number of processors available to the Java virtual machine - * - * @group param - */ - final val parallelism = new IntParam(this, "parallelism", "number of threads to use") - - /** @group getParam */ - final def getParallelism: Int = $(parallelism) - - /** - * Param for the output format to produce. 
One of "full", "minimal" Setting this to minimal is more efficient - * when all you need is the identifier with its neighbors - * - * Default: "full" - * - * @group param - */ - final val outputFormat = new Param[String](this, "outputFormat", "output format to produce") - - /** @group getParam */ - final def getOutputFormat: String = $(outputFormat) - - setDefault(k -> 5, predictionCol -> "prediction", featuresCol -> "features", - excludeSelf -> false, similarityThreshold -> -1, outputFormat -> "full") - - protected def validateAndTransformSchema(schema: StructType, identifierDataType: DataType): StructType = { - - val distanceType = schema(getFeaturesCol).dataType match { - case ArrayType(FloatType, _) => FloatType - case _ => DoubleType - } - - val predictionStruct = new StructType() - .add("neighbor", identifierDataType, nullable = false) - .add("distance", distanceType, nullable = false) - - val neighborsField = StructField(getPredictionCol, new ArrayType(predictionStruct, containsNull = false)) - - getOutputFormat match { - case "minimal" if !isSet(queryIdentifierCol) => throw new IllegalArgumentException("queryIdentifierCol must be set when using outputFormat minimal.") - case "minimal" => - new StructType() - .add(schema(getQueryIdentifierCol)) - .add(neighborsField) - case _ => - if (schema.fieldNames.contains(getPredictionCol)) { - throw new IllegalArgumentException(s"Output column $getPredictionCol already exists.") - } - schema - .add(neighborsField) - } - } -} - -/** - * Params for knn algorithms. - */ -private[knn] trait KnnAlgorithmParams extends KnnModelParams { - - /** - * Param for the column name for the row identifier. - * Default: "id" - * - * @group param - */ - final val identifierCol = new Param[String](this, "identifierCol", "column name for the row identifier") - - /** @group getParam */ - final def getIdentifierCol: String = $(identifierCol) - - /** - * Number of partitions (default: 1) - */ - final val numPartitions = new IntParam(this, "numPartitions", - "number of partitions", ParamValidators.gt(0)) - - /** @group getParam */ - final def getNumPartitions: Int = $(numPartitions) - - /** - * Param for the distance function to use. One of "bray-curtis", "canberra", "cosine", "correlation", "euclidean", - * "inner-product", "manhattan" or the fully qualified classname of a distance function - * Default: "cosine" - * - * @group param - */ - final val distanceFunction = new Param[String](this, "distanceFunction", "distance function to use") - - /** @group getParam */ - final def getDistanceFunction: String = $(distanceFunction) - - /** - * Param for the partition identifier - */ - final val partitionCol = new Param[String](this, "partitionCol", "column name for the partition identifier") - - /** @group getParam */ - final def getPartitionCol: String = $(partitionCol) - - /** - * Param to the initial model. All the vectors from the initial model will included in the final output model. - */ - final val initialModelPath = new Param[String](this, "initialModelPath", "path to the initial model") - - /** @group getParam */ - final def getInitialModelPath: String = $(initialModelPath) - - setDefault(identifierCol -> "id", distanceFunction -> "cosine", numPartitions -> 1, numReplicas -> 0) -} - -/** - * Persists a knn model. 
- * - * @param instance the instance to persist - * - * @tparam TModel type of the model - * @tparam TId type of the index item identifier - * @tparam TVector type of the index item vector - * @tparam TItem type of the index item - * @tparam TDistance type of distance - * @tparam TIndex type of the index - */ -private[knn] class KnnModelWriter[ - TModel <: KnnModelBase[TModel], - TId: TypeTag, - TVector : TypeTag, - TItem <: Item[TId, TVector] with Product : TypeTag, - TDistance: TypeTag, - TIndex <: Index[TId, TVector, TItem, TDistance] -] (instance: TModel with KnnModelOps[TModel, TId, TVector, TItem, TDistance, TIndex]) - extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - val params = JObject( - instance.extractParamMap().toSeq.toList - // cannot use parse because of incompatibilities between json4s 3.2.11 used by spark 2.3 and 3.6.6 used by spark 2.4 - .map { case ParamPair(param, value) => - val fieldName = param.name - val fieldValue = mapper.readValue(param.jsonEncode(value), classOf[JValue]) - JField(fieldName, fieldValue) - } - ) - - val metaData = JObject(List( - JField("class", JString(instance.getClass.getName)), - JField("timestamp", JLong(System.currentTimeMillis())), - JField("sparkVersion", JString(sc.version)), - JField("uid", JString(instance.uid)), - JField("identifierType", JString(typeDescription[TId])), - JField("vectorType", JString(typeDescription[TVector])), - JField("partitions", JInt(instance.getNumPartitions)), - JField("paramMap", params) - )) - - val metadataPath = new Path(path, "metadata").toString - sc.parallelize(Seq(compact(metaData)), numSlices = 1).saveAsTextFile(metadataPath) - - val indicesPath = new Path(path, "indices").toString - - val modelOutputDir = instance.outputDir - - val serializableConfiguration = new SerializableConfiguration(sc.hadoopConfiguration) - - sc.range(start = 0, end = instance.getNumPartitions).foreach { partitionId => - val originPath = new Path(modelOutputDir, partitionId.toString) - val originFileSystem = originPath.getFileSystem(serializableConfiguration.value) - - if (originFileSystem.exists(originPath)) { - val destinationPath = new Path(indicesPath, partitionId.toString) - val destinationFileSystem = destinationPath.getFileSystem(serializableConfiguration.value) - FileUtil.copy(originFileSystem, originPath, destinationFileSystem, destinationPath, false, serializableConfiguration.value) - } - } - } - - private def typeDescription[T: TypeTag] = typeOf[T] match { - case t if t =:= typeOf[Int] => "int" - case t if t =:= typeOf[Long] => "long" - case t if t =:= typeOf[String] => "string" - case t if t =:= typeOf[Array[Float]] => "float_array" - case t if t =:= typeOf[Array[Double]] => "double_array" - case t if t =:= typeOf[Vector] => "vector" - case _ => "unknown" - } -} - -/** - * Reads a knn model from persistent storage. 
- * - * @param ev classtag - * @tparam TModel type of model - */ -private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](implicit ev: ClassTag[TModel]) - extends MLReader[TModel] { - - private implicit val format: Formats = DefaultFormats - - override def load(path: String): TModel = { - - val metadataPath = new Path(path, "metadata").toString - - val metadataStr = sc.textFile(metadataPath, 1).first() - - // cannot use parse because of incompatibilities between json4s 3.2.11 used by spark 2.3 and 3.6.6 used by spark 2.4 - val metadata = mapper.readValue(metadataStr, classOf[JValue]) - - val uid = (metadata \ "uid").extract[String] - - val identifierType = (metadata \ "identifierType").extract[String] - val vectorType = (metadata \ "vectorType").extract[String] - val partitions = (metadata \ "partitions").extract[Int] - - val paramMap = (metadata \ "paramMap").extract[JObject] - - val indicesPath = new Path(path, "indices").toString - - val model = (identifierType, vectorType) match { - case ("int", "float_array") => createModel[Int, Array[Float], IntFloatArrayIndexItem, Float](uid, indicesPath, partitions) - case ("int", "double_array") => createModel[Int, Array[Double], IntDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("int", "vector") => createModel[Int, Vector, IntVectorIndexItem, Double](uid, indicesPath, partitions) - - case ("long", "float_array") => createModel[Long, Array[Float], LongFloatArrayIndexItem, Float](uid, indicesPath, partitions) - case ("long", "double_array") => createModel[Long, Array[Double], LongDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("long", "vector") => createModel[Long, Vector, LongVectorIndexItem, Double](uid, indicesPath, partitions) - - case ("string", "float_array") => createModel[String, Array[Float], StringFloatArrayIndexItem, Float](uid, indicesPath, partitions) - case ("string", "double_array") => createModel[String, Array[Double], StringDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("string", "vector") => createModel[String, Vector, StringVectorIndexItem, Double](uid, indicesPath, partitions) - case _ => throw new IllegalStateException(s"Cannot create model for identifier type $identifierType and vector type $vectorType.") - } - - paramMap.obj.foreach { case (paramName, jsonValue) => - val param = model.getParam(paramName) - model.set(param, param.jsonDecode(compact(render(jsonValue)))) - } - - model - } - - /** - * Creates the model to be returned from fitting the data. - * - * @param uid identifier - * @param outputDir directory containing the persisted indices - * @param numPartitions number of index partitions - * - * @tparam TId type of the index item identifier - * @tparam TVector type of the index item vector - * @tparam TItem type of the index item - * @tparam TDistance type of distance between items - * @return model - */ - protected def createModel[ - TId : TypeTag, - TVector : TypeTag, - TItem <: Item[TId, TVector] with Product : TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int) - (implicit ev: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) : TModel - -} - -/** - * Base class for nearest neighbor search models. 
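On the Python side this writer/reader pair surfaces as the usual MLWritable/MLReadable API. A minimal sketch, assuming a fitted HnswSimilarityModel and a writable path:

model.write().overwrite().save('/tmp/hnsw_model')    # persists the metadata plus the per-partition indices

from pyspark_hnsw.knn import HnswSimilarityModel
reloaded = HnswSimilarityModel.read().load('/tmp/hnsw_model')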
- * - * @tparam TModel type of the model - **/ -private[knn] abstract class KnnModelBase[TModel <: KnnModelBase[TModel]] extends Model[TModel] with KnnModelParams { - - private[knn] def outputDir: String - - def getNumPartitions: Int - - /** @group setParam */ - def setQueryIdentifierCol(value: String): this.type = set(queryIdentifierCol, value) - - /** @group setParam */ - def setQueryPartitionsCol(value: String): this.type = set(queryPartitionsCol, value) - - /** @group setParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group setParam */ - def setPredictionCol(value: String): this.type = set(predictionCol, value) - - /** @group setParam */ - def setK(value: Int): this.type = set(k, value) - - /** @group setParam */ - def setExcludeSelf(value: Boolean): this.type = set(excludeSelf, value) - - /** @group setParam */ - def setSimilarityThreshold(value: Double): this.type = set(similarityThreshold, value) - - /** @group setParam */ - def setNumReplicas(value: Int): this.type = set(numReplicas, value) - - /** @group setParam */ - def setParallelism(value: Int): this.type = set(parallelism, value) - - /** @group setParam */ - def setOutputFormat(value: String): this.type = set(outputFormat, value) - -} - -/** - * Contains the core knn search logic - * - * @tparam TModel type of the model - * @tparam TId type of the index item identifier - * @tparam TVector type of the index item vector - * @tparam TItem type of the index item - * @tparam TDistance type of distance between items - * @tparam TIndex type of the index - */ -private[knn] trait KnnModelOps[ - TModel <: KnnModelBase[TModel], - TId, - TVector, - TItem <: Item[TId, TVector] with Product, - TDistance, - TIndex <: Index[TId, TVector, TItem, TDistance] -] { - this: TModel with KnnModelParams => - - protected def loadIndex(in: InputStream): TIndex - - protected def typedTransform(dataset: Dataset[_]) - (implicit tId: TypeTag[TId], tVector: TypeTag[TVector], tDistance: TypeTag[TDistance], evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) : DataFrame = { - - if (!isSet(queryIdentifierCol) && getExcludeSelf) { - throw new IllegalArgumentException("QueryIdentifierCol must be defined when excludeSelf is true.") - } - - if (isSet(queryIdentifierCol)) typedTransformWithQueryCol[TId](dataset, getQueryIdentifierCol) - else typedTransformWithQueryCol[Long](dataset.withColumn("_query_id", monotonically_increasing_id), "_query_id").drop("_query_id") - } - - protected def typedTransformWithQueryCol[TQueryId](dataset: Dataset[_], queryIdCol: String) - (implicit tId: TypeTag[TId], tVector: TypeTag[TVector], tDistance: TypeTag[TDistance], tQueryId: TypeTag[TQueryId], evId: ClassTag[TId], evVector: ClassTag[TVector], evQueryId: ClassTag[TQueryId], distanceNumeric: Numeric[TDistance]) : DataFrame = { - import dataset.sparkSession.implicits._ - import distanceNumeric._ - - implicit val encoder: Encoder[TQueryId] = ExpressionEncoder() - implicit val neighborOrdering: Ordering[Neighbor[TId, TDistance]] = Ordering.by(_.distance) - - val serializableHadoopConfiguration = new SerializableConfiguration(dataset.sparkSession.sparkContext.hadoopConfiguration) - - // construct the queries to the distributed indices. 
when query partitions are specified we only query those partitions - // otherwise we query all partitions - val logicalPartitionAndQueries = - if (isDefined(queryPartitionsCol)) dataset - .select( - col(getQueryPartitionsCol), - col(queryIdCol), - col(getFeaturesCol) - ) - .as[(Seq[Int], TQueryId, TVector)] - .rdd - .flatMap { case (queryPartitions, queryId, vector) => - queryPartitions.map { partition => (partition, (queryId, vector)) } - } - else dataset - .select( - col(queryIdCol), - col(getFeaturesCol) - ) - .as[(TQueryId, TVector)] - .rdd - .flatMap { case (queryId, vector) => - Range(0, getNumPartitions).map { partition => - (partition, (queryId, vector)) - } - } - - val numPartitionCopies = getNumReplicas + 1 - - val physicalPartitionAndQueries = logicalPartitionAndQueries - .map { case (partition, (queryId, vector)) => - val randomCopy = ThreadLocalRandom.current().nextInt(numPartitionCopies) - val physicalPartition = (partition * numPartitionCopies) + randomCopy - (physicalPartition, (queryId, vector)) - } - .partitionBy(new PartitionIdPassthrough(getNumPartitions * numPartitionCopies)) - - val numThreads = - if (isSet(parallelism) && getParallelism <= 0) sys.runtime.availableProcessors - else if (isSet(parallelism)) getParallelism - else dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", defaultValue = 1) - - val neighborsOnAllQueryPartitions = physicalPartitionAndQueries - .mapPartitions { queriesWithPartition => - - val queries = queriesWithPartition.map(_._2) - - // load the partitioned index and execute all queries. - - val physicalPartitionId = TaskContext.getPartitionId() - - val logicalPartitionId = physicalPartitionId / numPartitionCopies - val replica = physicalPartitionId % numPartitionCopies - - val indexPath = new Path(outputDir, logicalPartitionId.toString) - - val fileSystem = indexPath.getFileSystem(serializableHadoopConfiguration.value) - - if (!fileSystem.exists(indexPath)) Iterator.empty - else { - - logInfo(logicalPartitionId, replica, s"started loading index from $indexPath on host ${InetAddress.getLocalHost.getHostName}") - val index = loadIndex(fileSystem.open(indexPath)) - logInfo(logicalPartitionId, replica, s"finished loading index from $indexPath on host ${InetAddress.getLocalHost.getHostName}") - - // execute queries in parallel on multiple threads - new Iterator[(TQueryId, Seq[Neighbor[TId, TDistance]])] { - - private[this] var first = true - private[this] var count = 0 - - private[this] val batchSize = 1000 - private[this] val queue = new LinkedBlockingQueue[(TQueryId, Seq[Neighbor[TId, TDistance]])](batchSize * numThreads) - private[this] val executorService = new ThreadPoolExecutor(numThreads, numThreads, 60L, - TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable], new NamedThreadFactory("searcher-%d")) { - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - - Option(t).orElse { - r match { - case t: FutureTask[_] => Try(t.get()).failed.toOption.map { - case e: ExecutionException => e.getCause - case e: InterruptedException => - Thread.currentThread().interrupt() - e - case NonFatal(e) => e - } - case _ => None - } - }.foreach { e => - logError("Error in worker.", e) - } - } - } - executorService.allowCoreThreadTimeOut(true) - - private[this] val activeWorkers = new CountDownLatch(numThreads) - Range(0, numThreads).map(_ => new Worker(queries, activeWorkers, batchSize)).foreach(executorService.submit) - - override def hasNext: Boolean = { - if (!queue.isEmpty) true - else if 
(queries.synchronized { queries.hasNext }) true - else { - // in theory all workers could have just picked up the last new work but not started processing any of it. - if (!activeWorkers.await(2, TimeUnit.MINUTES)) { - throw new IllegalStateException("Workers failed to complete.") - } - !queue.isEmpty - } - } - - override def next(): (TQueryId, Seq[Neighbor[TId, TDistance]]) = { - if (first) { - logInfo(logicalPartitionId, replica, s"started querying on host ${InetAddress.getLocalHost.getHostName} with ${sys.runtime.availableProcessors} available processors.") - first = false - } - - val value = queue.poll(1, TimeUnit.MINUTES) - - count += 1 - - if (!hasNext) { - logInfo(logicalPartitionId, replica, s"finished querying $count items on host ${InetAddress.getLocalHost.getHostName}") - - executorService.shutdown() - } - - value - } - - class Worker(queries: Iterator[(TQueryId, TVector)], activeWorkers: CountDownLatch, batchSize: Int) extends Runnable { - - private[this] var work = List.empty[(TQueryId, TVector)] - - private[this] val fetchSize = - if (getExcludeSelf) getK + 1 - else getK - - @tailrec final override def run(): Unit = { - - work.foreach { case (id, vector) => - - val neighbors = index.findNearest(vector, fetchSize) - .collect { case SearchResult(item, distance) - if (!getExcludeSelf || item.id != id) && (getSimilarityThreshold < 0 || distance.toDouble < getSimilarityThreshold) => - Neighbor[TId, TDistance](item.id, distance) - } - - queue.put(id -> neighbors) - } - - work = queries.synchronized { - queries.take(batchSize).toList - } - - if (work.nonEmpty) { - run() - } else { - activeWorkers.countDown() - } - } - } - } - } - }.toDS() - - // take the best k results from all partitions - - val topNeighbors = neighborsOnAllQueryPartitions - .groupByKey { case (queryId, _) => queryId } - .flatMapGroups { (queryId, groups) => - val allNeighbors = groups.flatMap { case (_, neighbors) => neighbors}.toList - Iterator.single(queryId -> allNeighbors.sortBy(_.distance).take(getK)) - } - .toDF(queryIdCol, getPredictionCol) - - if (getOutputFormat == "minimal") topNeighbors - else dataset.join(topNeighbors, Seq(queryIdCol)) - } - - protected def typedTransformSchema[T: TypeTag](schema: StructType): StructType = { - val idDataType = typeOf[T] match { - case t if t =:= typeOf[Int] => IntegerType - case t if t =:= typeOf[Long] => LongType - case _ => StringType - } - validateAndTransformSchema(schema, idDataType) - } - - private def logInfo(partition: Int, replica: Int, message: String): Unit = - logInfo(f"partition $partition%04d replica $replica%04d: $message") - -} - -private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](override val uid: String) - extends Estimator[TModel] with KnnAlgorithmParams { - - /** - * Type of index. - * - * @tparam TId Type of the external identifier of an item - * @tparam TVector Type of the vector to perform distance calculation on - * @tparam TItem Type of items stored in the index - * @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..) 
- */ - protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] <: Index[TId, TVector, TItem, TDistance] - - /** @group setParam */ - def setIdentifierCol(value: String): this.type = set(identifierCol, value) - - /** @group setParam */ - def setQueryIdentifierCol(value: String): this.type = set(queryIdentifierCol, value) - - /** @group setParam */ - def setPartitionCol(value: String): this.type = set(partitionCol, value) - - /** @group setParam */ - def setQueryPartitionsCol(value: String): this.type = set(queryPartitionsCol, value) - - /** @group setParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group setParam */ - def setPredictionCol(value: String): this.type = set(predictionCol, value) - - /** @group setParam */ - def setK(value: Int): this.type = set(k, value) - - /** @group setParam */ - def setNumPartitions(value: Int): this.type = set(numPartitions, value) - - /** @group setParam */ - def setDistanceFunction(value: String): this.type = set(distanceFunction, value) - - /** @group setParam */ - def setExcludeSelf(value: Boolean): this.type = set(excludeSelf, value) - - /** @group setParam */ - def setSimilarityThreshold(value: Double): this.type = set(similarityThreshold, value) - - /** @group setParam */ - def setNumReplicas(value: Int): this.type = set(numReplicas, value) - - /** @group setParam */ - def setParallelism(value: Int): this.type = set(parallelism, value) - - /** @group setParam */ - def setOutputFormat(value: String): this.type = set(outputFormat, value) - - def setInitialModelPath(value: String): this.type = set(initialModelPath, value) - - override def fit(dataset: Dataset[_]): TModel = { - - val identifierType = dataset.schema(getIdentifierCol).dataType - val vectorType = dataset.schema(getFeaturesCol).dataType - - val model = (identifierType, vectorType) match { - case (IntegerType, ArrayType(FloatType, _)) => typedFit[Int, Array[Float], IntFloatArrayIndexItem, Float](dataset) - case (IntegerType, ArrayType(DoubleType, _)) => typedFit[Int, Array[Double], IntDoubleArrayIndexItem, Double](dataset) - case (IntegerType, VectorType) => typedFit[Int, Vector, IntVectorIndexItem, Double](dataset) - case (LongType, ArrayType(FloatType, _)) => typedFit[Long, Array[Float], LongFloatArrayIndexItem, Float](dataset) - case (LongType, ArrayType(DoubleType, _)) => typedFit[Long, Array[Double], LongDoubleArrayIndexItem, Double](dataset) - case (LongType, VectorType) => typedFit[Long, Vector, LongVectorIndexItem, Double](dataset) - case (StringType, ArrayType(FloatType, _)) => typedFit[String, Array[Float], StringFloatArrayIndexItem, Float](dataset) - case (StringType, ArrayType(DoubleType, _)) => typedFit[String, Array[Double], StringDoubleArrayIndexItem, Double](dataset) - case (StringType, VectorType) => typedFit[String, Vector, StringVectorIndexItem, Double](dataset) - case _ => - throw new IllegalArgumentException(s"Cannot create index for items with identifier of type " + - s"${identifierType.simpleString} and vector of type ${vectorType.simpleString}. " + - s"Supported identifiers are int, long and string. Supported vectors are array<float>, array<double> and vector.") - } - - copyValues(model) - } - - override def transformSchema(schema: StructType): StructType = - validateAndTransformSchema(schema, schema(getIdentifierCol).dataType) - - override def copy(extra: ParamMap): Estimator[TModel] = defaultCopy(extra) - - /** - * Create the index used to do the nearest neighbor search.
- * - * @param dimensions dimensionality of the items stored in the index - * @param maxItemCount maximum number of items the index can hold - * @param distanceFunction the distance function - * @param distanceOrdering the distance ordering - * @param idSerializer invoked for serializing ids when saving the index - * @param itemSerializer invoked for serializing items when saving items - * - * @tparam TId type of the index item identifier - * @tparam TVector type of the index item vector - * @tparam TItem type of the index item - * @tparam TDistance type of distance between items - * @return create an index - */ - protected def createIndex[ - TId, - TVector, - TItem <: Item[TId, TVector] with Product, - TDistance - ](dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance]) - (implicit distanceOrdering: Ordering[TDistance], idSerializer: ObjectSerializer[TId], itemSerializer: ObjectSerializer[TItem]) - : TIndex[TId, TVector, TItem, TDistance] - - /** - * Load an index - * - * @param inputStream InputStream to restore the index from - * @param minCapacity loaded index needs to have space for at least this man additional items - * - * @tparam TId type of the index item identifier - * @tparam TVector type of the index item vector - * @tparam TItem type of the index item - * @tparam TDistance type of distance between items - * @return create an index - */ - protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance](inputStream: InputStream, - minCapacity: Int) - : TIndex[TId, TVector, TItem, TDistance] - - /** - * Creates the model to be returned from fitting the data. - * - * @param uid identifier - * @param outputDir directory containing the persisted indices - * @param numPartitions number of index partitions - * - * @tparam TId type of the index item identifier - * @tparam TVector type of the index item vector - * @tparam TItem type of the index item - * @tparam TDistance type of distance between items - * @return model - */ - protected def createModel[ - TId : TypeTag, - TVector : TypeTag, - TItem <: Item[TId, TVector] with Product : TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int) - (implicit ev: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) - : TModel - - private def typedFit[ - TId : TypeTag, - TVector : TypeTag, - TItem <: Item[TId, TVector] with Product : TypeTag, - TDistance: TypeTag - ](dataset: Dataset[_]) - (implicit ev: ClassTag[TId], evVector: ClassTag[TVector], evItem: ClassTag[TItem], distanceNumeric: Numeric[TDistance], distanceFunctionFactory: String => DistanceFunction[TVector, TDistance], idSerializer: ObjectSerializer[TId], itemSerializer: ObjectSerializer[TItem]) - : TModel = { - - val sc = dataset.sparkSession - val sparkContext = sc.sparkContext - - val serializableHadoopConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) - - import sc.implicits._ - - val cacheFolder = sparkContext.getConf.get(key = "spark.hnswlib.settings.index.cache_folder", defaultValue = "/tmp") - - val outputDir = new Path(cacheFolder,s"${uid}_${System.currentTimeMillis()}").toString - - sparkContext.addSparkListener(new CleanupListener(outputDir, serializableHadoopConfiguration)) - - // read the id and vector from the input dataset and and repartition them over numPartitions amount of partitions. 
- // if the data is pre-partitioned by the user repartition the input data by the user defined partition key, use a - // hash of the item id otherwise. - val partitionedIndexItems = { - if (isDefined(partitionCol)) dataset - .select( - col(getPartitionCol).as("partition"), - struct(col(getIdentifierCol).as("id"), col(getFeaturesCol).as("vector")) - ) - .as[(Int, TItem)] - .rdd - .partitionBy(new PartitionIdPassthrough(getNumPartitions)) - .values - .toDS - else dataset - .select( - col(getIdentifierCol).as("id"), - col(getFeaturesCol).as("vector")) - .as[TItem] - .repartition(getNumPartitions, $"id") - } - - // On each partition collect all the items into memory and construct the HNSW indices. - // Save these indices to the hadoop filesystem - - val numThreads = - if (isSet(parallelism) && getParallelism <= 0) sys.runtime.availableProcessors - else if (isSet(parallelism)) getParallelism - else dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", defaultValue = 1) - - val initialModelOutputDir = - if (isSet(initialModelPath)) Some(new Path(getInitialModelPath, "indices").toString) - else None - - val serializableConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) - - partitionedIndexItems - .foreachPartition { it: Iterator[TItem] => - if (it.hasNext) { - val partitionId = TaskContext.getPartitionId() - - val items = it.toSeq - - logInfo(partitionId,s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") - - val existingIndexOption = initialModelOutputDir - .flatMap { dir => - val indexPath = new Path(dir, partitionId.toString) - val fs = indexPath.getFileSystem(serializableConfiguration.value) - - if (fs.exists(indexPath)) Some { - val inputStream = fs.open(indexPath) - loadIndex[TId, TVector, TItem, TDistance](inputStream, items.size) - } else { - logInfo(partitionId, s"File $indexPath not found.") - None - } - } - - val index = existingIndexOption - .getOrElse(createIndex[TId, TVector, TItem, TDistance](items.head.dimensions, items.size, distanceFunctionFactory(getDistanceFunction))) - - index.addAll(items, progressUpdateInterval = 5000, listener = (workDone, max) => logDebug(f"partition $partitionId%04d: Indexed $workDone of $max items"), numThreads = numThreads) - - logInfo(partitionId, s"finished indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") - - val path = new Path(outputDir, partitionId.toString) - val fileSystem = path.getFileSystem(serializableHadoopConfiguration.value) - - val outputStream = fileSystem.create(path) - - logInfo(partitionId, s"started saving index to $path on host ${InetAddress.getLocalHost.getHostName}") - - index.save(outputStream) - - logInfo(partitionId, s"finished saving index to $path on host ${InetAddress.getLocalHost.getHostName}") - } - } - - createModel[TId, TVector, TItem, TDistance](uid, outputDir, getNumPartitions) - } - - private def logInfo(partition: Int, message: String): Unit = logInfo(f"partition $partition%04d: $message") - - implicit private def floatArrayDistanceFunction(name: String): DistanceFunction[Array[Float], Float] = name match { - case "bray-curtis" => floatBrayCurtisDistance - case "canberra" => floatCanberraDistance - case "correlation" => floatCorrelationDistance - case "cosine" => floatCosineDistance - case "euclidean" => floatEuclideanDistance - case "inner-product" => floatInnerProduct - case "manhattan" => floatManhattanDistance - case value => userDistanceFunction(value) - } - - implicit private def 
doubleArrayDistanceFunction(name: String): DistanceFunction[Array[Double], Double] = name match { - case "bray-curtis" => doubleBrayCurtisDistance - case "canberra" => doubleCanberraDistance - case "correlation" => doubleCorrelationDistance - case "cosine" => doubleCosineDistance - case "euclidean" => doubleEuclideanDistance - case "inner-product" => doubleInnerProduct - case "manhattan" => doubleManhattanDistance - case value => userDistanceFunction(value) - } - - implicit private def vectorDistanceFunction(name: String): DistanceFunction[Vector, Double] = name match { - case "bray-curtis" => VectorDistanceFunctions.brayCurtisDistance - case "canberra" => VectorDistanceFunctions.canberraDistance - case "correlation" => VectorDistanceFunctions.correlationDistance - case "cosine" => VectorDistanceFunctions.cosineDistance - case "euclidean" => VectorDistanceFunctions.euclideanDistance - case "inner-product" => VectorDistanceFunctions.innerProduct - case "manhattan" => VectorDistanceFunctions.manhattanDistance - case value => userDistanceFunction(value) - } - - private def userDistanceFunction[TVector, TDistance](name: String): DistanceFunction[TVector, TDistance] = - Try(Class.forName(name).getDeclaredConstructor().newInstance()) - .toOption - .collect { case f: DistanceFunction[TVector @unchecked, TDistance @unchecked] => f } - .getOrElse(throw new IllegalArgumentException(s"$name is not a valid distance functions.")) -} - -private[knn] class CleanupListener(dir: String, serializableConfiguration: SerializableConfiguration) extends SparkListener with Logging { - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - - val path = new Path(dir) - val fileSystem = path.getFileSystem(serializableConfiguration.value) - - logInfo(s"Deleting files below $dir") - fileSystem.delete(path, true) - } -} - -/** - * Partitioner that uses precomputed partitions - * - * @param numPartitions number of partitions - */ -private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { - override def getPartition(key: Any): Int = key.asInstanceOf[Int] -} - diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala deleted file mode 100644 index 343299ff..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala +++ /dev/null @@ -1,103 +0,0 @@ -package com.github.jelmerk.spark.knn.bruteforce - -import java.io.InputStream - -import com.github.jelmerk.knn.ObjectSerializer - -import scala.reflect.runtime.universe._ -import scala.reflect.ClassTag -import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} -import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter} -import com.github.jelmerk.knn.scalalike.bruteforce.BruteForceIndex -import com.github.jelmerk.spark.knn._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Dataset} - -/** - * Companion class for BruteForceSimilarityModel. 
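So the distanceFunction param resolves the built-in names listed above and otherwise tries to instantiate the value as the fully qualified name of a DistanceFunction implementation on the classpath. A sketch from the Python side; the custom class name is purely illustrative:

from pyspark_hnsw.knn import HnswSimilarity

hnsw = HnswSimilarity(identifierCol='id', featuresCol='features',
                      distanceFunction='inner-product', k=10)

# Or plug in a custom implementation available on the executor classpath (hypothetical class):
hnsw.setDistanceFunction('com.example.spark.MyDistanceFunction')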
- */ -object BruteForceSimilarityModel extends MLReadable[BruteForceSimilarityModel] { - - private[knn] class BruteForceModelReader extends KnnModelReader[BruteForceSimilarityModel] { - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance : TypeTag - ](uid: String, outputDir: String, numPartitions: Int) - (implicit evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) : BruteForceSimilarityModel = - new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - - } - - override def read: MLReader[BruteForceSimilarityModel] = new BruteForceModelReader -} - -/** - * Model produced by `BruteForceSimilarity`. - */ -abstract class BruteForceSimilarityModel extends KnnModelBase[BruteForceSimilarityModel] with KnnModelParams with MLWritable - - -private[knn] class BruteForceSimilarityModelImpl[ - TId : TypeTag, - TVector : TypeTag, - TItem <: Item[TId, TVector] with Product : TypeTag, - TDistance : TypeTag -](override val uid: String, val outputDir: String, numPartitions: Int) - (implicit evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) - extends BruteForceSimilarityModel with KnnModelOps[BruteForceSimilarityModel, TId, TVector, TItem, TDistance, BruteForceIndex[TId, TVector, TItem, TDistance]] { - - override def getNumPartitions: Int = numPartitions - - override def transform(dataset: Dataset[_]): DataFrame = typedTransform(dataset) - - override def copy(extra: ParamMap): BruteForceSimilarityModel = { - val copied = new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - copyValues(copied, extra).setParent(parent) - } - - override def transformSchema(schema: StructType): StructType = typedTransformSchema[TId](schema) - - override def write: MLWriter = new KnnModelWriter[BruteForceSimilarityModel, TId, TVector, TItem, TDistance, BruteForceIndex[TId, TVector, TItem, TDistance]](this) - - override protected def loadIndex(in: InputStream): BruteForceIndex[TId, TVector, TItem, TDistance] = - BruteForceIndex.loadFromInputStream[TId, TVector, TItem, TDistance](in) - -} - -/** - * Nearest neighbor search using a brute force approach. This will be very slow. It is in most cases not recommended - * for production use. But can be used to determine the accuracy of an approximative index. 
- * - * @param uid identifier - */ -class BruteForceSimilarity(override val uid: String) extends KnnAlgorithm[BruteForceSimilarityModel](uid) { - - override protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = BruteForceIndex[TId, TVector, TItem, TDistance] - - def this() = this(Identifiable.randomUID("brute_force")) - - override protected def createIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance] - (dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance])(implicit distanceOrdering: Ordering[TDistance], idSerializer: ObjectSerializer[TId], itemSerializer: ObjectSerializer[TItem]) - : BruteForceIndex[TId, TVector, TItem, TDistance] = - BruteForceIndex[TId, TVector, TItem, TDistance]( - dimensions, - distanceFunction - ) - - override protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance] - (inputStream: InputStream, minCapacity: Int): BruteForceIndex[TId, TVector, TItem, TDistance] = BruteForceIndex.loadFromInputStream(inputStream) - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance : TypeTag - ](uid: String, outputDir: String, numPartitions: Int) - (implicit evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) : BruteForceSimilarityModel = - new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluator.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluator.scala deleted file mode 100644 index 75fb3add..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluator.scala +++ /dev/null @@ -1,119 +0,0 @@ -package com.github.jelmerk.spark.knn.evaluation - -import scala.reflect.runtime.universe._ -import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} -import org.apache.spark.sql.Dataset -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{ArrayType, IntegerType, LongType, StringType, StructField, StructType} - -/** - * Companion class for KnnSimilarityEvaluator. - */ -object KnnSimilarityEvaluator extends DefaultParamsReadable[KnnSimilarityEvaluator] { - override def load(path: String): KnnSimilarityEvaluator = super.load(path) -} - -/** - * Evaluator for knn algorithms, which expects two input columns, the exact neighbors and approximate neighbors. It compares - * the results to determine the accuracy of the approximate results. Typically you will want to compute this over a - * small sample given the cost of computing the exact results on a large index. - * - * @param uid identifier - */ -class KnnSimilarityEvaluator(override val uid: String) extends Evaluator with DefaultParamsWritable { - - def this() = this(Identifiable.randomUID("knn_eval")) - - /** - * Param for the column name for the approximate results. 
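// A minimal, self-contained sketch of how this evaluator is used, modelled on
// KnnSimilarityEvaluatorSpec further down in this patch; the data and column names are
// illustrative and a local SparkSession is created just for the example.
package com.example

import com.github.jelmerk.spark.knn.evaluation.KnnSimilarityEvaluator
import org.apache.spark.sql.SparkSession

case class Neighbor(neighbor: Int, distance: Float)

object EvaluatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("evaluator-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      // first query: 2 of 2 exact neighbors recovered, second query: 1 of 2
      (Seq(Neighbor(1, 0.1f), Neighbor(2, 0.2f)), Seq(Neighbor(1, 0.1f), Neighbor(2, 0.2f))),
      (Seq(Neighbor(3, 0.1f)),                    Seq(Neighbor(3, 0.1f), Neighbor(4, 0.9f)))
    ).toDF("approximate", "exact")

    val evaluator = new KnnSimilarityEvaluator()
      .setApproximateNeighborsCol("approximate")
      .setExactNeighborsCol("exact")

    // accuracy = sum(matching) / sum(exact) = (2 + 1) / (2 + 2) = 0.75
    println(evaluator.evaluate(df))

    spark.stop()
  }
}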
- * Default: "approximateNeighbors" - * - * @group param - */ - final val approximateNeighborsCol = new Param[String](this, "approximateNeighborsCol", "column containing the approximate neighbors") - - /** - * @group getParam - */ - final def getApproximateNeighborsCol: String = $(approximateNeighborsCol) - - /** - * @group setParam - */ - final def setApproximateNeighborsCol(value: String): this.type = set(approximateNeighborsCol, value) - - /** - * Param for the column name for the exact results. - * Default: "exactNeighbors" - * - * @group param - */ - final val exactNeighborsCol = new Param[String](this, "exactNeighborsCol", "column containing the exact neighbors") - - /** - * @group getParam - */ - final def getExactNeighborsCol: String = $(exactNeighborsCol) - - /** - * @group setParam - */ - final def setExactNeighborsCol(value: String): this.type = set(exactNeighborsCol, value) - - /** - * Returns the accuracy of the approximate results. - * - * @param dataset a dataset - * @return the accuracy of the approximate results - */ - override def evaluate(dataset: Dataset[_]): Double = { - if (!dataset.schema.fieldNames.contains(getExactNeighborsCol)) throw new IllegalArgumentException(s"Column $getExactNeighborsCol does not exist.") - if (!dataset.schema.fieldNames.contains(getApproximateNeighborsCol)) throw new IllegalArgumentException(s"Column $getApproximateNeighborsCol does not exist.") - - (dataset.schema(getExactNeighborsCol).dataType, dataset.schema(getApproximateNeighborsCol).dataType) match { - case (ArrayType(StructType(Array(StructField("neighbor", IntegerType, _, _), - StructField("distance", _, _, _))), _), - ArrayType(StructType(Array(StructField("neighbor", IntegerType, _, _), - StructField("distance", _, _, _))), _)) => typedEvaluate[Int](dataset) - - case (ArrayType(StructType(Array(StructField("neighbor", LongType, _, _), - StructField("distance", _, _, _))), _), - ArrayType(StructType(Array(StructField("neighbor", LongType, _, _), - StructField("distance", _, _, _))), _)) => typedEvaluate[Long](dataset) - - case (ArrayType(StructType(Array(StructField("neighbor", StringType, _, _), - StructField("distance", _, _, _))), _), - ArrayType(StructType(Array(StructField("neighbor", StringType, _, _), - StructField("distance", _, _, _))), _)) => typedEvaluate[String](dataset) - - case _ => throw new IllegalArgumentException(s"Column $getExactNeighborsCol and or $getApproximateNeighborsCol is not of the correct type.") - } - } - - private def typedEvaluate[TId : TypeTag](dataset: Dataset[_]): Double = { - import dataset.sparkSession.implicits._ - - dataset - .select( - col(s"$getExactNeighborsCol.neighbor"), - col(s"$getApproximateNeighborsCol.neighbor") - ) - .as[(Seq[TId], Seq[TId])] - .mapPartitions( it => it.map { case (exactNeighbors, approximateNeighbors) => - exactNeighbors.toSet.intersect(approximateNeighbors.toSet).size -> exactNeighbors.size - }) - .toDF("numMatching", "numResults") - .select(when(sum($"numResults") === 0, 1.0).otherwise(sum($"numMatching") / sum($"numResults"))) - .as[Double] - .collect() - .head - } - - override def copy(extra: ParamMap): Evaluator = this.defaultCopy(extra) - - override def isLargerBetter: Boolean = true - - setDefault(approximateNeighborsCol -> "approximateNeighbors", exactNeighborsCol -> "exactNeighbors") -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala deleted file mode 100644 index 
21d8c2a7..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala +++ /dev/null @@ -1,181 +0,0 @@ -package com.github.jelmerk.spark.knn.hnsw - -import java.io.InputStream - -import com.github.jelmerk.knn - -import scala.reflect.runtime.universe._ -import scala.reflect.ClassTag -import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter} -import org.apache.spark.sql.{DataFrame, Dataset} -import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} -import com.github.jelmerk.knn.scalalike.hnsw._ -import com.github.jelmerk.spark.knn._ -import org.apache.spark.sql.types.StructType - -private[hnsw] trait HnswParams extends KnnAlgorithmParams with HnswModelParams { - - /** - * The number of bi-directional links created for every new element during construction. - * - * Default: 16 - * - * @group param - */ - final val m = new IntParam(this, "m", - "number of bi-directional links created for every new element during construction", ParamValidators.gt(0)) - - /** @group getParam */ - final def getM: Int = $(m) - - /** - * Has the same meaning as ef, but controls the index time / index precision. - * Default: 200 - * - * @group param - */ - final val efConstruction = new IntParam(this, "efConstruction", - "has the same meaning as ef, but controls the index time / index precision", ParamValidators.gt(0)) - - /** @group getParam */ - final def getEfConstruction: Int = $(efConstruction) - - setDefault(m -> 16, efConstruction -> 200) -} - -/** - * Common params for Hnsw and HnswModel. - */ -private[hnsw] trait HnswModelParams extends KnnModelParams { - - /** - * Size of the dynamic list for the nearest neighbors (used during the search). - * Default: 10 - * - * @group param - */ - final val ef = new IntParam(this, "ef", - "size of the dynamic list for the nearest neighbors (used during the search)", ParamValidators.gt(0)) - - /** @group getParam */ - final def getEf: Int = $(ef) - - setDefault(ef -> 10) -} - - -/** - * Companion class for HnswSimilarityModel. - */ -object HnswSimilarityModel extends MLReadable[HnswSimilarityModel] { - - private[hnsw] class HnswModelReader extends KnnModelReader[HnswSimilarityModel] { - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance : TypeTag - ](uid: String, outputDir: String, numPartitions: Int) - (implicit evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) : HnswSimilarityModel = - new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - } - - override def read: MLReader[HnswSimilarityModel] = new HnswModelReader - -} - -/** - * Model produced by `HnswSimilarity`. 
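// A minimal sketch of how the parameters above are typically set, mirroring the setters
// defined below and the values used in the tests later in this patch. The DataFrame `items`
// (columns "id" and "vector") and the SparkSession are assumed to exist.
import com.github.jelmerk.spark.knn.hnsw.HnswSimilarity

val hnsw = new HnswSimilarity()
  .setIdentifierCol("id")
  .setQueryIdentifierCol("id")
  .setFeaturesCol("vector")
  .setM(16)               // bi-directional links per element: higher improves recall, costs memory
  .setEfConstruction(200) // build-time beam width: higher improves index quality, slows indexing
  .setK(5)

// ef only affects queries, so it can also be raised on the fitted model.
val model = hnsw.fit(items).setEf(50)

// Adds a column with the k nearest neighbors for every input row.
val withNeighbors = model.transform(items)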
- */ -abstract class HnswSimilarityModel extends KnnModelBase[HnswSimilarityModel] with HnswModelParams with MLWritable { - - /** @group setParam */ - def setEf(value: Int): this.type = set(ef, value) - -} - -private[knn] class HnswSimilarityModelImpl[ - TId : TypeTag, - TVector : TypeTag, - TItem <: Item[TId, TVector] with Product : TypeTag, - TDistance : TypeTag -](override val uid: String, val outputDir: String, numPartitions: Int) - (implicit evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) - extends HnswSimilarityModel with KnnModelOps[HnswSimilarityModel, TId, TVector, TItem, TDistance, HnswIndex[TId, TVector, TItem, TDistance]] { - - override def getNumPartitions: Int = numPartitions - - override def transform(dataset: Dataset[_]): DataFrame = typedTransform(dataset) - - override def copy(extra: ParamMap): HnswSimilarityModel = { - val copied = new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - copyValues(copied, extra).setParent(parent) - } - - override def transformSchema(schema: StructType): StructType = typedTransformSchema[TId](schema) - - override def write: MLWriter = new KnnModelWriter[HnswSimilarityModel, TId, TVector, TItem, TDistance, HnswIndex[TId, TVector, TItem, TDistance]](this) - - override protected def loadIndex(in: InputStream): HnswIndex[TId, TVector, TItem, TDistance] = { - val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](in) - index.ef = getEf - index - } -} - - -/** - * Nearest neighbor search using the approximative hnsw algorithm. - * - * @param uid identifier - */ -class HnswSimilarity(override val uid: String) extends KnnAlgorithm[HnswSimilarityModel](uid) with HnswParams { - - override protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = HnswIndex[TId, TVector, TItem, TDistance] - - def this() = this(Identifiable.randomUID("hnsw")) - - /** @group setParam */ - def setM(value: Int): this.type = set(m, value) - - /** @group setParam */ - def setEf(value: Int): this.type = set(ef, value) - - /** @group setParam */ - def setEfConstruction(value: Int): this.type = set(efConstruction, value) - - override protected def createIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance] - (dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance])(implicit distanceOrdering: Ordering[TDistance], idSerializer: knn.ObjectSerializer[TId], itemSerializer: knn.ObjectSerializer[TItem]) - : HnswIndex[TId, TVector, TItem, TDistance] = - HnswIndex[TId, TVector, TItem, TDistance]( - dimensions, - distanceFunction, - maxItemCount, - getM, - getEf, - getEfConstruction, - removeEnabled = false, - idSerializer, - itemSerializer - ) - - override protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance] - (inputStream: InputStream, minCapacity: Int): HnswIndex[TId, TVector, TItem, TDistance] = { - val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](inputStream) - index.resize(index.maxItemCount + minCapacity) - index - } - - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance : TypeTag - ](uid: String, outputDir: String, numPartitions: Int) - (implicit evId: ClassTag[TId], evVector: ClassTag[TVector], distanceNumeric: Numeric[TDistance]) : HnswSimilarityModel = - new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) -} - diff --git 
a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala deleted file mode 100644 index d73e089b..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala +++ /dev/null @@ -1,226 +0,0 @@ -package com.github.jelmerk.spark - -import java.io.{ObjectInput, ObjectOutput} - -import com.github.jelmerk.knn.scalalike.ObjectSerializer -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} - -package object knn { - - private[knn] implicit object StringSerializer extends ObjectSerializer[String] { - override def write(item: String, out: ObjectOutput): Unit = out.writeUTF(item) - override def read(in: ObjectInput): String = in.readUTF() - } - - private[knn] implicit object IntSerializer extends ObjectSerializer[Int] { - override def write(item: Int, out: ObjectOutput): Unit = out.writeInt(item) - override def read(in: ObjectInput): Int = in.readInt() - } - - private[knn] implicit object LongSerializer extends ObjectSerializer[Long] { - override def write(item: Long, out: ObjectOutput): Unit = out.writeLong(item) - override def read(in: ObjectInput): Long = in.readLong() - } - - private[knn] implicit object FloatArraySerializer extends ObjectSerializer[Array[Float]] { - override def write(item: Array[Float], out: ObjectOutput): Unit = { - out.writeInt(item.length) - item.foreach(out.writeFloat) - } - - override def read(in: ObjectInput): Array[Float] = { - val length = in.readInt() - val item = Array.ofDim[Float](length) - - for (i <- 0 until length) { - item(i) = in.readFloat() - } - item - } - } - - private[knn] implicit object DoubleArraySerializer extends ObjectSerializer[Array[Double]] { - override def write(item: Array[Double], out: ObjectOutput): Unit = { - out.writeInt(item.length) - item.foreach(out.writeDouble) - } - - override def read(in: ObjectInput): Array[Double] = { - val length = in.readInt() - val item = Array.ofDim[Double](length) - - for (i <- 0 until length) { - item(i) = in.readDouble() - } - item - } - } - - - private[knn] implicit object VectorSerializer extends ObjectSerializer[Vector] { - override def write(item: Vector, out: ObjectOutput): Unit = item match { - case v: DenseVector => - out.writeBoolean(true) - out.writeInt(v.size) - v.values.foreach(out.writeDouble) - - case v: SparseVector => - out.writeBoolean(false) - out.writeInt(v.size) - out.writeInt(v.indices.length) - v.indices.foreach(out.writeInt) - v.values.foreach(out.writeDouble) - } - - override def read(in: ObjectInput): Vector = { - val isDense = in.readBoolean() - val size = in.readInt() - - if (isDense) { - val values = Array.ofDim[Double](size) - - for (i <- 0 until size) { - values(i) = in.readDouble() - } - - Vectors.dense(values) - } else { - val numFilled = in.readInt() - val indices = Array.ofDim[Int](numFilled) - - for (i <- 0 until numFilled) { - indices(i) = in.readInt() - } - - val values = Array.ofDim[Double](numFilled) - - for (i <- 0 until numFilled) { - values(i) = in.readDouble() - } - - Vectors.sparse(size, indices, values) - } - } - } - - - private[knn] implicit object IntVectorIndexItemSerializer extends ObjectSerializer[IntVectorIndexItem] { - override def write(item: IntVectorIndexItem, out: ObjectOutput): Unit = { - IntSerializer.write(item.id, out) - VectorSerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): IntVectorIndexItem = { - val id = IntSerializer.read(in) - val vector = VectorSerializer.read(in) - IntVectorIndexItem(id, 
vector) - } - } - - private[knn] implicit object LongVectorIndexItemSerializer extends ObjectSerializer[LongVectorIndexItem] { - override def write(item: LongVectorIndexItem, out: ObjectOutput): Unit = { - LongSerializer.write(item.id, out) - VectorSerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): LongVectorIndexItem = { - val id = LongSerializer.read(in) - val vector = VectorSerializer.read(in) - LongVectorIndexItem(id, vector) - } - } - - private[knn] implicit object StringVectorIndexItemSerializer extends ObjectSerializer[StringVectorIndexItem] { - override def write(item: StringVectorIndexItem, out: ObjectOutput): Unit = { - StringSerializer.write(item.id, out) - VectorSerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): StringVectorIndexItem = { - val id = StringSerializer.read(in) - val vector = VectorSerializer.read(in) - StringVectorIndexItem(id, vector) - } - } - - - private[knn] implicit object IntFloatArrayIndexItemSerializer extends ObjectSerializer[IntFloatArrayIndexItem] { - override def write(item: IntFloatArrayIndexItem, out: ObjectOutput): Unit = { - IntSerializer.write(item.id, out) - FloatArraySerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): IntFloatArrayIndexItem = { - val id = IntSerializer.read(in) - val vector = FloatArraySerializer.read(in) - IntFloatArrayIndexItem(id, vector) - } - } - - private[knn] implicit object LongFloatArrayIndexItemSerializer extends ObjectSerializer[LongFloatArrayIndexItem] { - override def write(item: LongFloatArrayIndexItem, out: ObjectOutput): Unit = { - LongSerializer.write(item.id, out) - FloatArraySerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): LongFloatArrayIndexItem = { - val id = LongSerializer.read(in) - val vector = FloatArraySerializer.read(in) - LongFloatArrayIndexItem(id, vector) - } - } - - private[knn] implicit object StringFloatArrayIndexItemSerializer extends ObjectSerializer[StringFloatArrayIndexItem] { - override def write(item: StringFloatArrayIndexItem, out: ObjectOutput): Unit = { - StringSerializer.write(item.id, out) - FloatArraySerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): StringFloatArrayIndexItem = { - val id = StringSerializer.read(in) - val vector = FloatArraySerializer.read(in) - StringFloatArrayIndexItem(id, vector) - } - } - - - private[knn] implicit object IntDoubleArrayIndexItemSerializer extends ObjectSerializer[IntDoubleArrayIndexItem] { - override def write(item: IntDoubleArrayIndexItem, out: ObjectOutput): Unit = { - IntSerializer.write(item.id, out) - DoubleArraySerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): IntDoubleArrayIndexItem = { - val id = IntSerializer.read(in) - val vector = DoubleArraySerializer.read(in) - IntDoubleArrayIndexItem(id, vector) - } - } - - private[knn] implicit object LongDoubleArrayIndexItemSerializer extends ObjectSerializer[LongDoubleArrayIndexItem] { - override def write(item: LongDoubleArrayIndexItem, out: ObjectOutput): Unit = { - LongSerializer.write(item.id, out) - DoubleArraySerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): LongDoubleArrayIndexItem = { - val id = LongSerializer.read(in) - val vector = DoubleArraySerializer.read(in) - LongDoubleArrayIndexItem(id, vector) - } - } - - private[knn] implicit object StringDoubleArrayIndexItemSerializer extends ObjectSerializer[StringDoubleArrayIndexItem] { - override def write(item: StringDoubleArrayIndexItem, out: 
ObjectOutput): Unit = { - StringSerializer.write(item.id, out) - DoubleArraySerializer.write(item.vector, out) - } - - override def read(in: ObjectInput): StringDoubleArrayIndexItem = { - val id = StringSerializer.read(in) - val vector = DoubleArraySerializer.read(in) - StringDoubleArrayIndexItem(id, vector) - } - } - -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/Normalizer.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/Normalizer.scala deleted file mode 100644 index ef35f261..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/Normalizer.scala +++ /dev/null @@ -1,96 +0,0 @@ -package com.github.jelmerk.spark.linalg - -import com.github.jelmerk.knn.util.VectorUtils -import org.apache.spark.internal.Logging -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} -import org.apache.spark.ml.linalg.SQLDataTypes._ -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions._ - -/** - * Companion class for Normalizer. - */ -object Normalizer extends DefaultParamsReadable[Normalizer] { - override def load(path: String): Normalizer = super.load(path) -} - -/** - * Normalizes vectors to unit norm. - * - * @param uid identifier - */ -class Normalizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Logging with DefaultParamsWritable { - - def this() = this(Identifiable.randomUID("norm")) - - /** @group setParam */ - def setInputCol(value: String): this.type = set(inputCol, value) - - /** @group setParam */ - def setOutputCol(value: String): this.type = set(outputCol, value) - - override def transform(dataset: Dataset[_]): DataFrame = dataset.schema(getInputCol).dataType match { - case VectorType => dataset.withColumn(getOutputCol, normalizeVector(col(getInputCol))) - case ArrayType(FloatType, _) => dataset.withColumn(getOutputCol, normalizeFloatArray(col(getInputCol))) - case ArrayType(DoubleType, _) => dataset.withColumn(getOutputCol, normalizeDoubleArray(col(getInputCol))) - } - - override def copy(extra: ParamMap): Transformer = defaultCopy(extra) - - override def transformSchema(schema: StructType): StructType = { - if (schema.fieldNames.contains(getOutputCol)) { - throw new IllegalArgumentException(s"Output column $getOutputCol already exists.") - } - - if (!schema.fieldNames.contains(getInputCol)) { - throw new IllegalArgumentException(s"Input column $getInputCol does not exist.") - } - - val inputColumnSchema = schema(getInputCol) - - val inputColHasValidDataType = inputColumnSchema.dataType match { - case VectorType => true - case ArrayType(FloatType, _) => true - case ArrayType(DoubleType, _) => true - case _ => false - } - - if (!inputColHasValidDataType) { - throw new IllegalArgumentException(s"Input column $getInputCol must be a float array, double array or vector.") - } - - schema - .add(getOutputCol, inputColumnSchema.dataType, inputColumnSchema.nullable) - } - - private def magnitude(vector: Vector): Double = { - val values = vector match { - case v: SparseVector => v.values - case v: DenseVector => v.values - } - Math.sqrt(values.map(v => v * v).sum) - } - - private val 
normalizeFloatArray: UserDefinedFunction = udf { value: Seq[Float] => VectorUtils.normalize(value.toArray) } - - private val normalizeDoubleArray: UserDefinedFunction = udf { value: Seq[Double] => VectorUtils.normalize(value.toArray) } - - private val normalizeVector: UserDefinedFunction = udf[Vector, Vector] { value => - val normFactor = 1 / magnitude(value) - - value match { - case v: SparseVector => new SparseVector(v.size, v.indices, v.values.map(_ * normFactor)) - case v: DenseVector => new DenseVector(v.values.map(_ * normFactor)) - } - } - -} - - diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/DenseVectorDistanceFunctions.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/DenseVectorDistanceFunctions.scala deleted file mode 100644 index 0de5e185..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/DenseVectorDistanceFunctions.scala +++ /dev/null @@ -1,78 +0,0 @@ -package com.github.jelmerk.spark.linalg.functions - -import com.github.jelmerk.knn.scalalike._ -import org.apache.spark.ml.linalg.DenseVector - -object DenseVectorDistanceFunctions { - - /** - * Calculates the cosine distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Cosine distance between u and v. - */ - def cosineDistance(u: DenseVector, v: DenseVector): Double = doubleCosineDistance(u.values, v.values) - - /** - * Calculates the inner product. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Inner product between u and v. - */ - def innerProduct(u: DenseVector, v: DenseVector): Double = doubleInnerProduct(u.values, v.values) - - /** - * Calculates the Bray Curtis distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Bray Curtis distance between u and v. - */ - def brayCurtisDistance(u: DenseVector, v: DenseVector): Double = doubleBrayCurtisDistance(u.values, v.values) - - /** - * Calculates the canberra distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Canberra distance between u and v. - */ - def canberraDistance(u: DenseVector, v: DenseVector): Double = doubleCanberraDistance(u.values, v.values) - - /** - * Calculates the correlation distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Correlation distance between u and v. - */ - def correlationDistance(u: DenseVector, v: DenseVector): Double = doubleCorrelationDistance(u.values, v.values) - - /** - * Calculates the euclidean distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Euclidean distance between u and v. - */ - def euclideanDistance(u: DenseVector, v: DenseVector): Double = doubleEuclideanDistance(u.values, v.values) - - /** - * Calculates the manhattan distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Manhattan distance between u and v. 
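// A minimal usage sketch for the Normalizer above: every vector is scaled by
// 1 / sqrt(sum(v_i^2)) so that it has unit length. A SparkSession named `spark` is assumed;
// column names are illustrative.
import com.github.jelmerk.spark.linalg.Normalizer
import org.apache.spark.ml.linalg.Vectors

import spark.implicits._

val df = Seq(Tuple1(Vectors.dense(3.0, 4.0))).toDF("vector")

val normalized = new Normalizer()
  .setInputCol("vector")
  .setOutputCol("normalizedVector")
  .transform(df)

// The magnitude is sqrt(3^2 + 4^2) = 5, so the normalized vector is [0.6, 0.8].
normalized.show(truncate = false)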
- */ - def manhattanDistance(u: DenseVector, v: DenseVector): Double = doubleManhattanDistance(u.values, v.values) - -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctions.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctions.scala deleted file mode 100644 index ff1911b6..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctions.scala +++ /dev/null @@ -1,339 +0,0 @@ -package com.github.jelmerk.spark.linalg.functions - -import org.apache.spark.ml.linalg.SparseVector -import math.{abs, sqrt, pow} - -object SparseVectorDistanceFunctions { - - /** - * Calculates the cosine distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Cosine distance between u and v. - */ - def cosineDistance(u: SparseVector, v: SparseVector): Double = { - val denom = norm(u) * norm(v) - val dot = innerProduct(u, v) - - if (denom == 0.0) 1d - else 1 - dot / denom - } - - /** - * Calculates the inner product. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Inner product between u and v. - */ - def innerProductDistance(u: SparseVector, v: SparseVector): Double = 1 - innerProduct(u, v) - - /** - * Calculates the Bray Curtis distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Bray Curtis distance between u and v. - */ - def brayCurtisDistance(u: SparseVector, v: SparseVector): Double = { - val uIndices = u.indices - val vIndices = v.indices - - val uValues = u.values - val vValues = v.values - - var sump = 0.0 - var sumn = 0.0 - - var i = 0 - var j = 0 - - while(i < uIndices.length && j < vIndices.length) { - if (uIndices(i) < vIndices(j)) { - val incr = abs(uValues(i)) - - sumn += incr - sump += incr - i += 1 - } else if (uIndices(i) > vIndices(j)) { - val incr = abs(vValues(j)) - - sumn += incr - sump += incr - j += 1 - } else { - sumn += abs(uValues(i) - vValues(j)) - sump += abs(uValues(i) + vValues(j)) - i += 1 - j += 1 - } - } - - while(i < uIndices.length) { - val incr = abs(uValues(i)) - - sumn += incr - sump += incr - i += 1 - } - - while(j < vIndices.length) { - val incr = abs(vValues(j)) - - sumn += incr - sump += incr - j += 1 - } - - sumn / sump - } - - /** - * Calculates the canberra distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Canberra distance between u and v. - */ - def canberraDistance(u: SparseVector, v: SparseVector): Double = { - val uIndices = u.indices - val vIndices = v.indices - - val uValues = u.values - val vValues = v.values - - var distance = 0.0 - - var i = 0 - var j = 0 - - while(i < uIndices.length && j < vIndices.length) { - if (uIndices(i) < vIndices(j)) { - distance += 1 - i += 1 - } else if (uIndices(i) > vIndices(j)) { - distance += 1 - j += 1 - } else { - distance += abs(uValues(i) - vValues(j)) / (abs(uValues(i)) + abs(vValues(j))) - i += 1 - j += 1 - } - } - - distance + (uIndices.length - i) + (vIndices.length - j) - } - - /** - * Calculates the correlation distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Correlation distance between u and v. 
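// A small usage sketch for the sparse distances above; the values are illustrative. Both
// sorted index arrays are walked with two cursors, so only the non-zero entries of each
// vector are touched.
import com.github.jelmerk.spark.linalg.functions.SparseVectorDistanceFunctions
import org.apache.spark.ml.linalg.SparseVector

val u = new SparseVector(4, Array(0, 2), Array(1.0, 2.0))
val v = new SparseVector(4, Array(1, 2), Array(3.0, 4.0))

// cosine: 1 - (2 * 4) / (sqrt(1 + 4) * sqrt(9 + 16)), about 0.2845
val cosine = SparseVectorDistanceFunctions.cosineDistance(u, v)

// canberra: 1 (index 0 only in u) + 1 (index 1 only in v) + |2 - 4| / (2 + 4), about 2.3333
val canberra = SparseVectorDistanceFunctions.canberraDistance(u, v)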
- */ - def correlationDistance(u: SparseVector, v: SparseVector): Double = { - - val x = -u.values.sum / u.size - val y = -v.values.sum / v.size - - val uIndices = u.indices - val vIndices = v.indices - - val uValues = u.values - val vValues = v.values - - var num = 0.0 - var den1 = 0.0 - var den2 = 0.0 - - var left = u.size - - var i = 0 - var j = 0 - - val absXSquared = abs(pow(x, 2)) - - while(i < uIndices.length && j < vIndices.length) { - if (uIndices(i) < vIndices(j)) { - num += (uValues(i) + x) * y - - den1 += abs(pow(uValues(i) + x, 2)) - den2 += absXSquared - left -= 1 - - i += 1 - } else if (uIndices(i) > vIndices(j)) { - num += x * (vValues(j) + y) - - den1 += absXSquared - den2 += abs(pow(vValues(j) + x, 2)) - left -= 1 - j += 1 - } else { - num += (uValues(i) + x) * (vValues(j) + y) - - den1 += abs(pow(uValues(i) + x, 2)) - den2 += abs(pow(vValues(j) + x, 2)) - left -= 1 - - i += 1 - j += 1 - } - } - - while(i < uIndices.length) { - num += (uValues(i) + x) * y - - den1 += abs(pow(uValues(i) + x, 2)) - den2 += absXSquared - left -= 1 - - i += 1 - } - - while(j < vIndices.length) { - num += x * (vValues(j) + y) - - den1 += absXSquared - den2 += abs(pow(vValues(j) + x, 2)) - - j += 1 - } - - num += (x * y) * left - den1 += absXSquared * left - den2 += absXSquared * left - - 1 - (num / (sqrt(den1) * sqrt(den2))) - } - - /** - * Calculates the euclidean distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Euclidean distance between u and v. - */ - def euclideanDistance(u: SparseVector, v: SparseVector): Double = { - val uIndices = u.indices - val vIndices = v.indices - - val uValues = u.values - val vValues = v.values - - var sum = 0.0 - - var i = 0 - var j = 0 - - while(i < uIndices.length && j < vIndices.length) { - if (uIndices(i) < vIndices(j)) { - sum += pow(uValues(i), 2) - i += 1 - } else if (uIndices(i) > vIndices(j)) { - sum += pow(vValues(j), 2) - j += 1 - } else { - val dp = uValues(i) - vValues(j) - sum += pow(dp, 2) - i += 1 - j += 1 - } - } - - while(i < uIndices.length) { - sum += pow(uValues(i), 2) - i += 1 - } - - while(j < vIndices.length) { - sum += pow(vValues(j), 2) - j += 1 - } - - sqrt(sum) - } - - /** - * Calculates the manhattan distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Manhattan distance between u and v. 
- */ - def manhattanDistance(u: SparseVector, v: SparseVector): Double = { - val uIndices = u.indices - val vIndices = v.indices - - val uValues = u.values - val vValues = v.values - - var sum = 0.0 - - var i = 0 - var j = 0 - - while(i < uIndices.length && j < vIndices.length) { - if (uIndices(i) < vIndices(j)) { - sum += abs(uValues(i)) - i += 1 - } else if (uIndices(i) > vIndices(j)) { - sum += abs(vValues(j)) - j += 1 - } else { - sum += abs(uValues(i) - vValues(j)) - i += 1 - j += 1 - } - } - - while(i < uIndices.length) { - sum += abs(uValues(i)) - i += 1 - } - - while(j < vIndices.length) { - sum += abs(vValues(j)) - j += 1 - } - - sum - } - - private def norm(u: SparseVector): Double = sqrt(u.values.map(v => v * v).sum) - - private def innerProduct(u: SparseVector, v: SparseVector): Double = { - val uIndices = u.indices - val vIndices = v.indices - - val uValues = u.values - val vValues = v.values - - var dot = 0.0 - - var i = 0 - var j = 0 - - while(i < uIndices.length && j < vIndices.length) { - if (uIndices(i) < vIndices(j)) { - i += 1 - } else if (uIndices(i) > vIndices(j)) { - j += 1 - } else { - dot += uValues(i) * vValues(j) - i += 1 - j += 1 - } - } - dot - } - -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/VectorDistanceFunctions.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/VectorDistanceFunctions.scala deleted file mode 100644 index 844bbc6f..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/linalg/functions/VectorDistanceFunctions.scala +++ /dev/null @@ -1,111 +0,0 @@ -package com.github.jelmerk.spark.linalg.functions - -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} - -object VectorDistanceFunctions { - - /** - * Calculates the cosine distance. - * - * @param u Left vector. - * @param v Right vector - * - * @return Cosine distance between u and v. - */ - def cosineDistance(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.cosineDistance(ud, vd) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.cosineDistance(us, vs) - case (ud: DenseVector, vs: SparseVector) => SparseVectorDistanceFunctions.cosineDistance(ud.toSparse, vs) - case (us: SparseVector, vd: DenseVector) => SparseVectorDistanceFunctions.cosineDistance(us, vd.toSparse) - } - - /** - * Calculates the inner product. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Inner product between u and v. - */ - def innerProduct(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.innerProduct(ud, vd) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.innerProductDistance(us, vs) - case (ud: DenseVector, vs: SparseVector) => SparseVectorDistanceFunctions.innerProductDistance(ud.toSparse, vs) - case (us: SparseVector, vd: DenseVector) => SparseVectorDistanceFunctions.innerProductDistance(us, vd.toSparse) - } - - /** - * Calculates the Bray Curtis distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Bray Curtis distance between u and v. 
- */ - def brayCurtisDistance(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.brayCurtisDistance(ud, vd) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.brayCurtisDistance(us, vs) - case (us: SparseVector, vd: DenseVector) => SparseVectorDistanceFunctions.brayCurtisDistance(us, vd.toSparse) - case (ud: DenseVector, vs: SparseVector) => SparseVectorDistanceFunctions.brayCurtisDistance(ud.toSparse, vs) - } - - /** - * Calculates the canberra distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Canberra distance between u and v. - */ - def canberraDistance(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.canberraDistance(ud, vd) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.canberraDistance(us, vs) - case (ud: DenseVector, vs: SparseVector) => SparseVectorDistanceFunctions.canberraDistance(ud.toSparse, vs) - case (us: SparseVector, vd: DenseVector) => SparseVectorDistanceFunctions.canberraDistance(us, vd.toSparse) - } - - /** - * Calculates the correlation distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Correlation distance between u and v. - */ - def correlationDistance(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.correlationDistance(ud, vd) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.correlationDistance(us, vs) - case (ud: DenseVector, vs: SparseVector) => DenseVectorDistanceFunctions.correlationDistance(ud, vs.toDense) - case (us: SparseVector, vd: DenseVector) => DenseVectorDistanceFunctions.correlationDistance(us.toDense, vd) - } - - /** - * Calculates the euclidean distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Euclidean distance between u and v. - */ - def euclideanDistance(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.euclideanDistance(ud, vd) - case (us: SparseVector, vd: DenseVector) => SparseVectorDistanceFunctions.euclideanDistance(us, vd.toSparse) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.euclideanDistance(us, vs) - case (ud: DenseVector, vs: SparseVector) => SparseVectorDistanceFunctions.euclideanDistance(ud.toSparse, vs) - } - - /** - * Calculates the manhattan distance. - * - * @param u Left vector. - * @param v Right vector. - * - * @return Manhattan distance between u and v. 
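// A small usage sketch for the dispatching functions above; the values are illustrative.
// When one argument is dense and the other sparse, the dense side is converted so both use
// the sparse implementation (correlation is the exception and converts to dense).
import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions
import org.apache.spark.ml.linalg.Vectors

val dense  = Vectors.dense(1.0, 0.0, 2.0)
val sparse = Vectors.sparse(3, Array(0, 2), Array(1.0, 2.0))

// The two vectors hold the same values, so both distances come out as 0.0.
val euclidean = VectorDistanceFunctions.euclideanDistance(dense, sparse)
val cosine    = VectorDistanceFunctions.cosineDistance(dense, sparse)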
- */ - def manhattanDistance(u: Vector, v: Vector): Double = (u, v) match { - case (ud: DenseVector, vd: DenseVector) => DenseVectorDistanceFunctions.manhattanDistance(ud, vd) - case (us: SparseVector, vs: SparseVector) => SparseVectorDistanceFunctions.manhattanDistance(us, vs) - case (ud: DenseVector, vs: SparseVector) => SparseVectorDistanceFunctions.manhattanDistance(ud.toSparse.toSparse, vs) - case (us: SparseVector, vd: DenseVector) => SparseVectorDistanceFunctions.manhattanDistance(us, vd.toSparse) - } -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/SerializableConfiguration.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/SerializableConfiguration.scala deleted file mode 100644 index bd2543f4..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/SerializableConfiguration.scala +++ /dev/null @@ -1,20 +0,0 @@ -package com.github.jelmerk.spark.util - -import java.io.{ObjectInputStream, ObjectOutputStream} - -import org.apache.hadoop.conf.Configuration - -/** - * Copied from org.apache.spark.util.SerializableConfiguration - */ -private[spark] class SerializableConfiguration(@transient var value: Configuration) extends Serializable { - private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { - out.defaultWriteObject() - value.write(out) - } - - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - value = new Configuration(false) - value.readFields(in) - } -} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/Utils.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/Utils.scala deleted file mode 100644 index 0b6fa650..00000000 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/util/Utils.scala +++ /dev/null @@ -1,32 +0,0 @@ -package com.github.jelmerk.spark.util - -import java.io.IOException - -import org.apache.spark.internal.Logging - -import scala.util.control.NonFatal - -/** - * Copied from org.apache.spark.util.Utils - */ -private[spark] object Utils extends Logging { - - /** - * Execute a block of code that returns a value, re-throwing any non-fatal uncaught - * exceptions as IOException. This is used when implementing Externalizable and Serializable's - * read and write methods, since Java's serializer will not report non-IOExceptions properly; - * see SPARK-4080 for more context. 
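// A minimal sketch of the pattern described above and used by SerializableConfiguration:
// custom writeObject/readObject hooks wrap their bodies in Utils.tryOrIOException (defined
// just below) so that non-IOExceptions are not silently mishandled by Java serialization.
// The class is purely illustrative; it sits in a com.github.jelmerk.spark subpackage because
// Utils is private[spark].
package com.github.jelmerk.spark.example

import java.io.{ObjectInputStream, ObjectOutputStream}

import com.github.jelmerk.spark.util.Utils

class SerializableTags(@transient var tags: Seq[String]) extends Serializable {

  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    out.defaultWriteObject()
    out.writeInt(tags.length)
    tags.foreach(out.writeUTF)
  }

  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    val length = in.readInt()
    tags = (0 until length).map(_ => in.readUTF())
  }
}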
- */ - def tryOrIOException[T](block: => T): T = { - try { - block - } catch { - case e: IOException => - logError("Exception encountered", e) - throw e - case NonFatal(e) => - logError("Exception encountered", e) - throw new IOException(e) - } - } -} diff --git a/hnswlib-spark/src/test/python/conftest.py b/hnswlib-spark/src/test/python/conftest.py deleted file mode 100644 index 7dd0cf6e..00000000 --- a/hnswlib-spark/src/test/python/conftest.py +++ /dev/null @@ -1,18 +0,0 @@ -# coding=utf-8 - -import os - -import pytest - -from pyspark.sql import SparkSession - -@pytest.fixture(scope="session", autouse=True) -def spark(request): - sc = SparkSession.builder \ - .config("spark.driver.extraClassPath", os.environ["ARTIFACT_PATH"]) \ - .master("local[*]") \ - .getOrCreate() - - request.addfinalizer(lambda: sc.stop()) - - return sc diff --git a/hnswlib-spark/src/test/python/test_bruteforce.py b/hnswlib-spark/src/test/python/test_bruteforce.py deleted file mode 100644 index e405713d..00000000 --- a/hnswlib-spark/src/test/python/test_bruteforce.py +++ /dev/null @@ -1,23 +0,0 @@ -# coding=utf-8 - -from pyspark_hnsw.knn import BruteForceSimilarity -from pyspark.ml.linalg import Vectors - - -def test_bruteforce(spark): - - df = spark.createDataFrame([ - [1, Vectors.dense([0.2, 0.9])], - [2, Vectors.dense([0.2, 1.0])], - [3, Vectors.dense([0.2, 0.1])], - ], ['row_id', 'features']) - - bruteforce = BruteForceSimilarity(identifierCol='row_id', queryIdentifierCol='row_id', featuresCol='features', - distanceFunction='cosine', numPartitions=100, excludeSelf=False, - similarityThreshold=-1.0) - - model = bruteforce.fit(df) - - result = model.transform(df) - - assert result.count() == 3 diff --git a/hnswlib-spark/src/test/python/test_hnsw.py b/hnswlib-spark/src/test/python/test_hnsw.py deleted file mode 100644 index cb960588..00000000 --- a/hnswlib-spark/src/test/python/test_hnsw.py +++ /dev/null @@ -1,22 +0,0 @@ -# coding=utf-8 - -from pyspark_hnsw.knn import HnswSimilarity -from pyspark.ml.linalg import Vectors - - -def test_hnsw(spark): - - df = spark.createDataFrame([ - [1, Vectors.dense([0.2, 0.9])], - [2, Vectors.dense([0.2, 1.0])], - [3, Vectors.dense([0.2, 0.1])], - ], ['row_id', 'features']) - - hnsw = HnswSimilarity(identifierCol='row_id', featuresCol='features', distanceFunction='cosine', m=32, ef=5, k=5, - efConstruction=200, numPartitions=100, excludeSelf=False, similarityThreshold=-1.0) - - model = hnsw.fit(df) - - result = model.transform(df) - - assert result.count() == 3 diff --git a/hnswlib-spark/src/test/python/test_integration.py b/hnswlib-spark/src/test/python/test_integration.py deleted file mode 100644 index 3e2b9eea..00000000 --- a/hnswlib-spark/src/test/python/test_integration.py +++ /dev/null @@ -1,29 +0,0 @@ -# coding=utf-8 - -from pyspark_hnsw.knn import HnswSimilarity -from pyspark.ml.linalg import Vectors -from pyspark.sql import functions as F - - -def test_incremental_models(spark, tmp_path): - - df1 = spark.createDataFrame([ - [1, Vectors.dense([0.1, 0.2, 0.3])] - ], ['id', 'features']) - - hnsw1 = HnswSimilarity() - - model1 = hnsw1.fit(df1) - - model1.write().overwrite().save(tmp_path.as_posix()) - - df2 = spark.createDataFrame([ - [2, Vectors.dense([0.9, 0.1, 0.2])] - ], ['id', 'features']) - - hnsw2 = HnswSimilarity(initialModelPath=tmp_path.as_posix()) - - model2 = hnsw2.fit(df2) - - assert model2.transform(df1).select(F.explode("prediction")).count() == 2 - diff --git a/hnswlib-spark/src/test/python/test_knn_evaluator.py 
b/hnswlib-spark/src/test/python/test_knn_evaluator.py deleted file mode 100644 index 6735eaa2..00000000 --- a/hnswlib-spark/src/test/python/test_knn_evaluator.py +++ /dev/null @@ -1,21 +0,0 @@ -# coding=utf-8 - -from pyspark_hnsw.evaluation import KnnSimilarityEvaluator -from pyspark.sql.types import * - - -def test_evaluator(spark): - - neighbors_list_schema = ArrayType(StructType([StructField("neighbor", IntegerType()), StructField("distance", FloatType())])) - - schema = StructType([StructField("approximate", neighbors_list_schema), StructField("exact", neighbors_list_schema)]) - - df = spark.createDataFrame([ - [[{'neighbor': 1, 'distance': 0.1}], [{'neighbor': 1, 'distance': 0.1}]], - [[{'neighbor': 2, 'distance': 0.1}], [{'neighbor': 2, 'distance': 0.1}, {'neighbor': 3, 'distance': 0.9}]] - ], schema=schema) - - evaluator = KnnSimilarityEvaluator(approximateNeighborsCol='approximate', exactNeighborsCol='exact') - - assert evaluator.evaluate(df) == 0.6666666666666666 - diff --git a/hnswlib-spark/src/test/python/test_normalizer.py b/hnswlib-spark/src/test/python/test_normalizer.py deleted file mode 100644 index 61b911e6..00000000 --- a/hnswlib-spark/src/test/python/test_normalizer.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 - -from pyspark_hnsw.linalg import Normalizer -from pyspark.ml.linalg import Vectors - -def test_normalizer(spark): - - df = spark.createDataFrame([[Vectors.dense([0.01, 0.02, 0.03])]], ['vector']) - - normalizer = Normalizer(inputCol="vector", outputCol="normalized_vector") - - result = normalizer.transform(df) - - assert result.count() == 1 diff --git a/hnswlib-spark/src/test/python/test_vector_converter.py b/hnswlib-spark/src/test/python/test_vector_converter.py deleted file mode 100644 index 1f296a39..00000000 --- a/hnswlib-spark/src/test/python/test_vector_converter.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 - -from pyspark_hnsw.conversion import VectorConverter -from pyspark.ml.linalg import Vectors - -def test_vector_converter(spark): - - df = spark.createDataFrame([[Vectors.dense([0.01, 0.02, 0.03])]], ['vector']) - - converter = VectorConverter(inputCol="vector", outputCol="array", outputType="array") - - result = converter.transform(df) - - assert result.count() == 1 diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/SharedSparkContext.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/SharedSparkContext.scala deleted file mode 100644 index 7aeebdeb..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/SharedSparkContext.scala +++ /dev/null @@ -1,43 +0,0 @@ -package com.github.jelmerk.spark - -import org.apache.spark.SparkConf -import org.apache.spark.sql.{SQLContext, SparkSession} -import org.scalatest.{BeforeAndAfterAll, Suite} - -/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ -trait SharedSparkContext extends BeforeAndAfterAll { - self: Suite => - - @transient private var sparkSession: SparkSession = _ - - def appID: String = this.getClass.getName + math.floor(math.random * 10E4).toLong.toString - - def conf: SparkConf = { - new SparkConf(). - setMaster("local[*]"). - setAppName("test"). - set("spark.ui.enabled", "false"). - set("spark.app.id", appID). 
- set("spark.driver.host", "localhost") - } - - def spark: SQLContext = sparkSession.sqlContext - - override def beforeAll(): Unit = { - sparkSession = SparkSession.builder().config(conf).getOrCreate() - super.beforeAll() - } - - override def afterAll(): Unit = { - try { - Option(sparkSession).foreach { _.stop() } - // To avoid Akka rebinding to the same port, since it doesn't - // unbind immediately on shutdown. - System.clearProperty("spark.driver.port") - sparkSession = null - } finally { - super.afterAll() - } - } - -} \ No newline at end of file diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/conversion/VectorConverterSpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/conversion/VectorConverterSpec.scala deleted file mode 100644 index dd89e482..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/conversion/VectorConverterSpec.scala +++ /dev/null @@ -1,64 +0,0 @@ -package com.github.jelmerk.spark.conversion - -import com.holdenkarau.spark.testing.DataFrameSuiteBase -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.sql.DataFrame -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.prop.TableDrivenPropertyChecks._ - -case class InputRow[TVector](vector: TVector) - -case class OutputRow[TVectorIn, TVectorOut](vector: TVectorIn, array: TVectorOut) - -class VectorConverterSpec extends AnyFunSuite with DataFrameSuiteBase { - - test("convert vectors") { - - val sqlCtx = sqlContext - import sqlCtx.implicits._ - - val scenarios = Table[DataFrame, DataFrame, String]( - ("input", "expectedOutput", "outputType"), ( - Seq(InputRow(Vectors.dense(Array(1d, 2d, 3d)))).toDF(), - Seq(OutputRow(Vectors.dense(Array(1d, 2d, 3d)), Array(1f, 2f, 3f))).toDF(), - "array" - ), ( - Seq(InputRow(Array(1d, 2d, 3d))).toDF(), - Seq(OutputRow(Array(1d, 2d, 3d), Array(1f, 2f, 3f))).toDF(), - "array" - ), ( - Seq(InputRow(Vectors.dense(Array(1d, 2d, 3d)))).toDF(), - Seq(OutputRow(Vectors.dense(Array(1d, 2d, 3d)), Array(1d, 2d, 3d))).toDF(), - "array" - ), ( - Seq(InputRow(Array(1f, 2f, 3f))).toDF(), - Seq(OutputRow(Array(1f, 2f, 3f), Array(1d, 2d, 3d))).toDF(), - "array" - ), ( - Seq(InputRow(Array(1f, 2f, 3f))).toDF(), - Seq(OutputRow(Array(1f, 2f, 3f), Vectors.dense(Array(1d, 2d, 3d)))).toDF(), - "vector" - ), ( - Seq(InputRow(Array(1d, 2d, 3d))).toDF(), - Seq(OutputRow(Array(1d, 2d, 3d), Vectors.dense(Array(1d, 2d, 3d)))).toDF(), - "vector" - ) - ) - - val input = Seq(InputRow(Array(1d, 2d, 3d))).toDF() - val expectedOutput = Seq(OutputRow(Array(1d, 2d, 3d), Vectors.dense(Array(1d, 2d, 3d)))).toDF() - val outputType = "vector" - -// forAll (scenarios) { case (input, expectedOutput, outputType) => - - val converter = new VectorConverter() - .setInputCol("vector") - .setOutputCol("array") - .setOutputType(outputType) - -// converter.transform(input).show() - assertDataFrameEquals(converter.transform(input), expectedOutput) -// } - - } -} diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/SerializerSpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/SerializerSpec.scala deleted file mode 100644 index 22a1d009..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/SerializerSpec.scala +++ /dev/null @@ -1,57 +0,0 @@ -package com.github.jelmerk.spark.knn - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} - -import com.github.jelmerk.knn.scalalike.{Item, ObjectSerializer} -import org.apache.spark.ml.linalg.{Vectors, Vector} -import 
org.scalatest.funsuite.AnyFunSuite -import org.scalatest.matchers.should.Matchers._ - -class SerializerSpec extends AnyFunSuite { - - test("serialize objects") { - validateSerializability(IntSerializer, 1) - validateSerializability(LongSerializer, 1L) - validateSerializability(StringSerializer, "1") - validateSerializability(VectorSerializer, Vectors.dense(Array(0.1, 0.2))) - validateSerializability(VectorSerializer, Vectors.sparse(3, Array(0, 1), Array(0.1, 0.2))) - validateSerializability(FloatArraySerializer, Array(0.1f, 0.2f)) - validateSerializability(DoubleArraySerializer, Array(0.1, 0.2)) - validateSerializability(IntDoubleArrayIndexItemSerializer, IntDoubleArrayIndexItem(1, Array(0.1, 0.2)), deepCompare[Int, Array[Double]]) - validateSerializability(LongDoubleArrayIndexItemSerializer, LongDoubleArrayIndexItem(1L, Array(0.1, 0.2)), deepCompare[Long, Array[Double]]) - validateSerializability(StringDoubleArrayIndexItemSerializer, StringDoubleArrayIndexItem("1", Array(0.1, 0.2)), deepCompare[String, Array[Double]]) - - validateSerializability(IntFloatArrayIndexItemSerializer, IntFloatArrayIndexItem(1, Array(0.1f, 0.2f)), deepCompare[Int, Array[Float]]) - validateSerializability(LongFloatArrayIndexItemSerializer, LongFloatArrayIndexItem(1L, Array(0.1f, 0.2f)), deepCompare[Long, Array[Float]]) - validateSerializability(StringFloatArrayIndexItemSerializer, StringFloatArrayIndexItem("1", Array(0.1f, 0.2f)), deepCompare[String, Array[Float]]) - - validateSerializability(IntVectorIndexItemSerializer, IntVectorIndexItem(1, Vectors.dense(0.1, 0.2)), deepCompare[Int, Vector]) - validateSerializability(LongVectorIndexItemSerializer, LongVectorIndexItem(1L, Vectors.dense(0.1, 0.2)), deepCompare[Long, Vector]) - validateSerializability(StringVectorIndexItemSerializer, StringVectorIndexItem("1", Vectors.dense(0.1, 0.2)), deepCompare[String, Vector]) - } - - private def validateSerializability[T](serializer: ObjectSerializer[T], value: T, - validation: (T, T) => Unit = simpleCompare[T] _): Unit = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - serializer.write(value, oos) - - oos.flush() - - val bais = new ByteArrayInputStream(baos.toByteArray) - val ois = new ObjectInputStream(bais) - val read = serializer.read(ois) - - validation(read, value) - } - - private def simpleCompare[T](in: T, out: T): Unit = { - in should be (out) - } - - private def deepCompare[TId, TVector](in: Item[TId, TVector], out: Item[TId, TVector]): Unit = { - in.vector should be(out.vector) - in.id should be (out.id) - } - -} diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluatorSpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluatorSpec.scala deleted file mode 100644 index 30a3e736..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/evaluation/KnnSimilarityEvaluatorSpec.scala +++ /dev/null @@ -1,42 +0,0 @@ -package com.github.jelmerk.spark.knn.evaluation - -import com.github.jelmerk.spark.SharedSparkContext -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.matchers.should.Matchers._ - -case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) - -class KnnSimilarityEvaluatorSpec extends AnyFunSuite with SharedSparkContext { - - test("evaluate performance") { - val sqlCtx = spark - import sqlCtx.implicits._ - - val evaluator = new KnnSimilarityEvaluator() - .setApproximateNeighborsCol("approximate") - .setExactNeighborsCol("exact") - - val df = 
Seq( - Seq(Neighbor("1", 0.1f), Neighbor("2", 0.2f)) -> Seq(Neighbor("1", 0.1f), Neighbor("2", 0.2f)), - Seq(Neighbor("3", 0.1f)) -> Seq(Neighbor("3", 0.1f), Neighbor("4", 0.9f)) - ).toDF("approximate", "exact") - - evaluator.evaluate(df) should be (0.75) - } - - test("evaluate performance empty lists") { - val sqlCtx = spark - import sqlCtx.implicits._ - - val evaluator = new KnnSimilarityEvaluator() - .setApproximateNeighborsCol("approximate") - .setExactNeighborsCol("exact") - - val df = Seq( - Seq.empty[Neighbor[Int, Float]] -> Seq.empty[Neighbor[Int, Float]] - ).toDF("approximate", "exact") - - evaluator.evaluate(df) should be (1) - } - -} diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala deleted file mode 100644 index f2ffa2bf..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala +++ /dev/null @@ -1,255 +0,0 @@ -package com.github.jelmerk.spark.knn.hnsw - -import java.io.File -import java.nio.file.Files -import java.util.UUID - -import com.holdenkarau.spark.testing.DataFrameSuiteBase -import org.apache.commons.io.FileUtils -import org.apache.commons.lang.builder.{EqualsBuilder, HashCodeBuilder} -import org.apache.spark.SparkConf -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.sql.DataFrame -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.matchers.should.Matchers._ -import org.scalatest.prop.TableDrivenPropertyChecks._ - -case class PrePartitionedInputRow[TId, TVector](partition: Int, id: TId, vector: TVector) - -case class QueryRow[TId, TVector](partitions: Seq[Int], id: TId, vector: TVector) - -case class InputRow[TId, TVector](id: TId, vector: TVector) - -case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) - -case class FullOutputRow[TId, TVector, TDistance](id: TId, vector: TVector, neighbors: Seq[Neighbor[TId, TDistance]]) { - - // case classes won't work because array equals is implemented as identity equality - override def equals(other: Any): Boolean = EqualsBuilder.reflectionEquals(this, other) - override def hashCode(): Int = HashCodeBuilder.reflectionHashCode(this) -} - -case class MinimalOutputRow[TId, TDistance](id: TId, neighbors: Seq[Neighbor[TId, TDistance]]) { - - // case classes won't work because array equals is implemented as identity equality - override def equals(other: Any): Boolean = EqualsBuilder.reflectionEquals(this, other) - override def hashCode(): Int = HashCodeBuilder.reflectionHashCode(this) -} - -class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase { - - // for some reason kryo cannot serialize the hnswindex so configure it to make sure it never gets serialized - override def conf: SparkConf = super.conf - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - - test("prepartitioned data") { - - val sqlCtx = sqlContext - import sqlCtx.implicits._ - - val hnsw = new HnswSimilarity() - .setIdentifierCol("id") - .setQueryIdentifierCol("id") - .setFeaturesCol("vector") - .setPartitionCol("partition") - .setQueryPartitionsCol("partitions") - .setNumPartitions(2) - .setNumReplicas(3) - .setK(10) - - val indexItems = Seq( - PrePartitionedInputRow(partition = 0, id = 1000000, vector = Vectors.dense(0.0110, 0.2341)), - PrePartitionedInputRow(partition = 0, id = 2000000, vector = Vectors.dense(0.2300, 0.3891)), - PrePartitionedInputRow(partition = 1, id = 
3000000, vector = Vectors.dense(0.4300, 0.9891)) - ).toDF() - - val model = hnsw.fit(indexItems).setPredictionCol("neighbors").setEf(10) - - val queries = Seq( - QueryRow(partitions = Seq(0), id = 123, vector = Vectors.dense(0.2400, 0.3891)) - ).toDF() - - val result = model.transform(queries) - .as[FullOutputRow[Int, DenseVector, Double]] - .collect() - .head - - result.neighbors.size should be (2) // it couldn't see 3000000 because we only query partition 0 - } - - test("find neighbors") { - - val sqlCtx = sqlContext - import sqlCtx.implicits._ - - val noSimilarityThreshold = -1.0 - - val denseVectorInput = Seq( - InputRow(1000000, Vectors.dense(0.0110, 0.2341)), - InputRow(2000000, Vectors.dense(0.2300, 0.3891)), - InputRow(3000000, Vectors.dense(0.4300, 0.9891)) - ).toDF() - - val denseVectorScenarioValidator: DataFrame => Unit = df => { - val rows = df.as[FullOutputRow[Int, DenseVector, Double]].collect() - - rows.find(_.id == 1000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(1000000, 3000000, 2000000)) - rows.find(_.id == 2000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(2000000, 3000000, 1000000)) - rows.find(_.id == 3000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 2000000, 1000000)) - } - - val minimalDenseVectorScenarioValidator: DataFrame => Unit = df => { - val rows = df.as[MinimalOutputRow[Int, Double]].collect() - - rows.find(_.id == 1000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(1000000, 3000000, 2000000)) - rows.find(_.id == 2000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(2000000, 3000000, 1000000)) - rows.find(_.id == 3000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 2000000, 1000000)) - } - - val sparseVectorInput = Seq( - InputRow(1000000, Vectors.sparse(2, Array(0, 1), Array(0.0110, 0.2341))), - InputRow(2000000, Vectors.sparse(2, Array(0, 1), Array(0.2300, 0.3891))), - InputRow(3000000, Vectors.sparse(2, Array(0, 1), Array(0.4300, 0.9891))) - ).toDF() - - val sparseVectorScenarioValidator: DataFrame => Unit = df => { - val rows = df.as[FullOutputRow[Int, SparseVector, Double]].collect() - - rows.find(_.id == 1000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 2000000)) - rows.find(_.id == 2000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 1000000)) - rows.find(_.id == 3000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(2000000, 1000000)) - } - - val similarityThresholdScenarioValidator: DataFrame => Unit = df => { - val rows = df.as[FullOutputRow[Int, DenseVector, Double]].collect() - - rows.find(_.id == 1000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(1000000, 3000000)) - rows.find(_.id == 2000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(2000000, 3000000)) - rows.find(_.id == 3000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 2000000, 1000000)) - } - - val floatArraySimilarityThresholdScenarioValidator: DataFrame => Unit= df => { - val rows = df.as[FullOutputRow[String, Array[Float], Float]].collect() - - rows.find(_.id == "1000000").toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq("1000000", "3000000")) - rows.find(_.id == "2000000").toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq("2000000", "3000000")) - rows.find(_.id == "3000000").toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq("3000000", "2000000", "1000000")) - } - - val doubleArrayInput = Seq( - InputRow(1000000, Array(0.0110, 0.2341)), 
- InputRow(2000000, Array(0.2300, 0.3891)), - InputRow(3000000, Array(0.4300, 0.9891)) - ).toDF() - - val doubleArrayScenarioValidator: DataFrame => Unit = df => { - val rows = df.as[FullOutputRow[Int, Array[Double], Double]].collect() - - rows.find(_.id == 1000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(1000000, 3000000, 2000000)) - rows.find(_.id == 2000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(2000000, 3000000, 1000000)) - rows.find(_.id == 3000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 2000000, 1000000)) - } - - val floatArrayInput = Seq( - InputRow("1000000", Array(0.0110f, 0.2341f)), - InputRow("2000000", Array(0.2300f, 0.3891f)), - InputRow("3000000", Array(0.4300f, 0.9891f)) - ).toDF() - - val floatArrayScenarioValidator: DataFrame => Unit = df => - df.as[FullOutputRow[String, Array[Float], Float]].collect() should contain only ( - FullOutputRow("2000000", Array(0.2300f, 0.3891f), Seq(Neighbor("2000000", 0.0f), Neighbor("3000000", 0.0076490045f), Neighbor("1000000", 0.11621308f))), - FullOutputRow("3000000", Array(0.4300f, 0.9891f), Seq(Neighbor("3000000", 0.0f), Neighbor("2000000", 0.0076490045f), Neighbor("1000000", 0.06521261f))), - FullOutputRow("1000000", Array(0.0110f, 0.2341f), Seq(Neighbor("1000000", 0.0f), Neighbor("3000000", 0.06521261f), Neighbor("2000000", 0.11621308f))) - ) - - val excludeSelfScenarioValidator: DataFrame => Unit = df => { - val rows = df.as[FullOutputRow[Int, SparseVector, Double]].collect() - - rows.find(_.id == 1000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 2000000)) - rows.find(_.id == 2000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(3000000, 1000000)) - rows.find(_.id == 3000000).toSeq.flatMap(_.neighbors.map(_.neighbor)) should be (Seq(2000000, 1000000)) - } - - val scenarios = Table[String, Boolean, Double, DataFrame, DataFrame => Unit]( - ("outputFormat", "excludeSelf", "similarityThreshold", "input", "validator"), - ("full", false, 1, denseVectorInput, denseVectorScenarioValidator), - ("minimal", false, 1, denseVectorInput, minimalDenseVectorScenarioValidator), - ("full", false, 0.1, denseVectorInput, similarityThresholdScenarioValidator), - ("full", false, 0.1, floatArrayInput, floatArraySimilarityThresholdScenarioValidator), - ("full", false, noSimilarityThreshold, doubleArrayInput, doubleArrayScenarioValidator), - ("full", false, noSimilarityThreshold, floatArrayInput, floatArrayScenarioValidator), - ("full", true, noSimilarityThreshold, denseVectorInput, excludeSelfScenarioValidator), - ("full", true, 1, sparseVectorInput, sparseVectorScenarioValidator) - ) - - forAll (scenarios) { case (outputFormat, excludeSelf, similarityThreshold, input, validator) => - - val hnsw = new HnswSimilarity() - .setIdentifierCol("id") - .setQueryIdentifierCol("id") - .setFeaturesCol("vector") - .setNumPartitions(5) - .setNumReplicas(3) - .setK(10) - .setExcludeSelf(excludeSelf) - .setSimilarityThreshold(similarityThreshold) - .setOutputFormat(outputFormat) - - val model = hnsw.fit(input).setPredictionCol("neighbors").setEf(10) - - val result = model.transform(input) - - validator(result) - } - } - - test("save and load model") { - - val sqlCtx = sqlContext - import sqlCtx.implicits._ - - val hnsw = new HnswSimilarity() - .setIdentifierCol("id") - .setQueryIdentifierCol("id") - .setFeaturesCol("vector") - .setPredictionCol("neighbors") - .setOutputFormat("minimal") - - val items = Seq( - InputRow(1000000, Array(0.0110f, 0.2341f)), - 
InputRow(2000000, Array(0.2300f, 0.3891f)), - InputRow(3000000, Array(0.4300f, 0.9891f)) - ).toDF() - - withTempFolder { folder => - - val path = new File(folder, "model").getCanonicalPath - - hnsw.fit(items).write.overwrite.save(path) - - val model = HnswSimilarityModel.load(path) - - val queryItems = Seq( - InputRow(1000000, Array(0.0110f, 0.2341f)) - ).toDF() - - val results = model.transform(queryItems).as[MinimalOutputRow[Int, Float]].collect() - - results.length should be(1) - results.head should be (MinimalOutputRow(1000000, Seq(Neighbor(1000000, 0.0f), Neighbor(3000000, 0.06521261f), Neighbor(2000000, 0.11621308f)))) - } - - } - - def withTempFolder[T](fn: File => T): T = { - val tempDir = Files.createTempDirectory(UUID.randomUUID().toString).toFile - try { - fn(tempDir) - } finally { - FileUtils.deleteDirectory(tempDir) - } - } - - -} \ No newline at end of file diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/NormalizerSpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/NormalizerSpec.scala deleted file mode 100644 index e9c40698..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/NormalizerSpec.scala +++ /dev/null @@ -1,45 +0,0 @@ -package com.github.jelmerk.spark.linalg - -import com.holdenkarau.spark.testing.DataFrameSuiteBase -import org.apache.spark.ml.linalg.{DenseVector, SparseVector} -import org.apache.spark.sql.DataFrame -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.prop.TableDrivenPropertyChecks._ - -case class InputRow[TVector](vector: TVector) - -case class OutputRow[TVector](vector: TVector, normalized: TVector) - -class NormalizerSpec extends AnyFunSuite with DataFrameSuiteBase { - - test("normalize vector") { - - val sqlCtx = sqlContext - import sqlCtx.implicits._ - - val normalize = new Normalizer() - .setInputCol("vector") - .setOutputCol("normalized") - - val scenarios = Table[DataFrame, DataFrame]( - ("input", "expectedOutput"), - ( - Seq(InputRow(new SparseVector(3, Array(0,1,2), Array(0.01, 0.02, 0.03)))).toDF(), Seq(OutputRow(new SparseVector(3, Array(0,1,2), Array(0.01, 0.02, 0.03)), - new SparseVector(3, Array(0,1,2), Array(0.2672612419124244, 0.5345224838248488, 0.8017837257372731)))).toDF() - ), ( - Seq(InputRow(new DenseVector(Array(0.01, 0.02, 0.03)))).toDF(), Seq(OutputRow(new DenseVector(Array(0.01, 0.02, 0.03)), - new DenseVector(Array(0.2672612419124244, 0.5345224838248488, 0.8017837257372731)))).toDF() - ), ( - Seq(InputRow(Array(0.01, 0.02, 0.03))).toDF(), Seq(OutputRow(Array(0.01, 0.02, 0.03), - Array(0.2672612419124244, 0.5345224838248488, 0.8017837257372731))).toDF() - ), ( - Seq(InputRow(Array(0.01f, 0.02f, 0.03f))).toDF(), Seq(OutputRow(Array(0.01f, 0.02f, 0.03f), Array(0.26726124f, 0.5345225f, 0.8017837f))).toDF() - ) - ) - - forAll (scenarios) { case (input, expectedOutput) => - assertDataFrameEquals(normalize.transform(input), expectedOutput) - } - - } -} diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctionsSpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctionsSpec.scala deleted file mode 100644 index 0e7fffdb..00000000 --- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/linalg/functions/SparseVectorDistanceFunctionsSpec.scala +++ /dev/null @@ -1,51 +0,0 @@ -package com.github.jelmerk.spark.linalg.functions - -import org.apache.spark.ml.linalg.Vector -import org.apache.spark.ml.linalg.Vectors -import org.scalactic.{Equality, 
TolerantNumerics} -import org.scalatest.funsuite.AnyFunSuite - -import scala.util.Random - -class SparseVectorDistanceFunctionsSpec extends AnyFunSuite { - - private implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.001) - private val random = new Random(1000L) - - - test("produce the same result as dense vector functions") { - - // the dense functions should be well tested, so let's just compare the results of the sparse functions with the dense counterparts - - for (_ <- 1 to 100) { - - val a = createRandomVector() - val b = createRandomVector() - - assert(DenseVectorDistanceFunctions.innerProduct(a.toDense, b.toDense) === SparseVectorDistanceFunctions.innerProductDistance(a.toSparse, b.toSparse)) - assert(DenseVectorDistanceFunctions.cosineDistance(a.toDense, b.toDense) === SparseVectorDistanceFunctions.cosineDistance(a.toSparse, b.toSparse)) - assert(DenseVectorDistanceFunctions.euclideanDistance(a.toDense, b.toDense) === SparseVectorDistanceFunctions.euclideanDistance(a.toSparse, b.toSparse)) - assert(DenseVectorDistanceFunctions.brayCurtisDistance(a.toDense, b.toDense) === SparseVectorDistanceFunctions.brayCurtisDistance(a.toSparse, b.toSparse)) - assert(DenseVectorDistanceFunctions.canberraDistance(a.toDense, b.toDense) === SparseVectorDistanceFunctions.canberraDistance(a.toSparse, b.toSparse)) - assert(DenseVectorDistanceFunctions.manhattanDistance(a.toDense, b.toDense) === SparseVectorDistanceFunctions.manhattanDistance(a.toSparse, b.toSparse)) - assert(DenseVectorDistanceFunctions.correlationDistance(a.toDense, b.toDense) === SparseVectorDistanceFunctions.correlationDistance(a.toSparse, b.toSparse)) - - } - - } - - def createRandomVector(): Vector = Vectors.dense( - Iterator.continually { - val roll = random.nextInt(10) - - if (roll < 2) 0.0 - else if (roll > 6) random.nextDouble() - else -random.nextDouble() - } - .take(1000) - .toArray - ) - - -} - diff --git a/project/plugins.sbt b/project/plugins.sbt index 9d260c5a..110d137c 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,4 +1,4 @@ -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0") addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.0.1") addSbtPlugin("net.aichler" % "sbt-jupiter-interface" % "0.11.1") -addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1") \ No newline at end of file +addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1") +addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.10.0") \ No newline at end of file diff --git a/scripts/colab_setup.sh b/scripts/colab_setup.sh deleted file mode 100644 index ac890dac..00000000 --- a/scripts/colab_setup.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -#default values for pyspark, pyspark-hnsw, and SPARK_HOME -HNSWLIB="1.1.0" -PYSPARK="3.0.3" - -while getopts s:p: option -do - case "${option}" - in - s) HNSWLIB=${OPTARG};; - p) PYSPARK=${OPTARG};; - esac -done - -echo "setup Colab for PySpark $PYSPARK and Hnswlib $HNSWLIB" -export JAVA_HOME="/usr/lib/jvm/java-11-openjdk-amd64" - -if [[ "$PYSPARK" == "3.1"* ]]; then - echo "Installing PySpark $PYSPARK and Hnswlib $HNSWLIB" -elif [[ "$PYSPARK" == "3.0"* ]]; then - echo "Installing PySpark $PYSPARK and Hnswlib $HNSWLIB" -elif [[ "$PYSPARK" == "2"* ]]; then - echo "Installing PySpark $PYSPARK and Hnswlib $HNSWLIB" - apt-get update - apt-get purge -y openjdk-11* -qq > /dev/null && sudo apt-get autoremove -y -qq > /dev/null - apt-get install -y openjdk-8-jdk-headless -qq > /dev/null - - SPARKHOME="/content/spark-2.4.8-bin-hadoop2.7" - export SPARK_HOME=$SPARKHOME - export 
JAVA_HOME="/usr/lib/jvm/java-8-openjdk-amd64" - - wget -q "https://downloads.apache.org/spark/spark-2.4.8/spark-2.4.8-bin-hadoop2.7.tgz" > /dev/null - tar -xvf spark-2.4.8-bin-hadoop2.7.tgz > /dev/null - -else - export JAVA_HOME="/usr/lib/jvm/java-11-openjdk-amd64" - PYSPARK="3.0.3" -fi - - -# Install pyspark pyspark-hnsw -pip install --upgrade -q pyspark==$PYSPARK pyspark-hnsw==$HNSWLIB findspark \ No newline at end of file diff --git a/scripts/kaggle_setup.sh b/scripts/kaggle_setup.sh deleted file mode 100644 index 53aea76c..00000000 --- a/scripts/kaggle_setup.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -#default values for pyspark, pyspark-hnsw, and SPARK_HOME -HNSWLIB="1.1.0" -PYSPARK="3.0.3" - -while getopts s:p: option -do - case "${option}" - in - s) HNSWLIB=${OPTARG};; - p) PYSPARK=${OPTARG};; - esac -done - -echo "setup Kaggle for PySpark $PYSPARK and Hnswlib $HNSWLIB" -export JAVA_HOME="/usr/lib/jvm/java-11-openjdk-amd64" - -if [[ "$PYSPARK" == "3.1"* ]]; then - echo "Installing PySpark $PYSPARK and Hnswlib $HNSWLIB" -elif [[ "$PYSPARK" == "3.0"* ]]; then - echo "Installing PySpark $PYSPARK and Hnswlib $HNSWLIB" -elif [[ "$PYSPARK" == "2"* ]]; then - echo "Installing PySpark $PYSPARK and Hnswlib $HNSWLIB" - apt-get update - apt-get purge -y openjdk-11* -qq > /dev/null && sudo apt-get autoremove -y -qq > /dev/null - apt-get install -y openjdk-8-jdk-headless -qq > /dev/null - - SPARKHOME="/content/spark-2.4.8-bin-hadoop2.7" - export SPARK_HOME=$SPARKHOME - export JAVA_HOME="/usr/lib/jvm/java-8-openjdk-amd64" - - wget -q "https://downloads.apache.org/spark/spark-2.4.8/spark-2.4.8-bin-hadoop2.7.tgz" > /dev/null - tar -xvf spark-2.4.8-bin-hadoop2.7.tgz > /dev/null - -else - export JAVA_HOME="/usr/lib/jvm/java-11-openjdk-amd64" - PYSPARK="3.0.3" -fi -# Install pyspark pyspark-hnsw -pip install --upgrade -q pyspark==$PYSPARK pyspark-hnsw==$HNSWLIB findspark \ No newline at end of file