Skip to content

Commit

Permalink
Add support for REAL_VECTOR utility functions
Browse files Browse the repository at this point in the history
Closes #325
  • Loading branch information
kasium committed Oct 8, 2024
1 parent 3a45bef commit fb63b9c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Please note, that only the following modules are considered to be part of the pu

- ``sqlalchemy_hana.types``
- ``sqlalchemy_hana.errors``
- ``sqlalchemy_hana.elements``
- ``sqlalchemy_hana.functions``

For these, only exported members (part of ``__all__`` ) are guaranteed to be stable.

Expand Down Expand Up @@ -231,6 +233,9 @@ For proper typing, the ``REAL_VECTOR`` class is generic and be set to the proper
Please note, that the generic type and ``vector_output_type`` should be kept in sync; this is not
enforced.

The ``sqlalchemy_hana.functions`` package defines certain utility functions like
``cosine_similarity``.

Regex
~~~~~
sqlalchemy-hana supports the ``regexp_match`` and ``regexp_replace``
Expand Down
38 changes: 38 additions & 0 deletions sqlalchemy_hana/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# pylint: disable=invalid-name
"""Custom SQL functions for SAP HANA."""

from __future__ import annotations

from typing import Generic

from sqlalchemy import Float, Integer
from sqlalchemy.sql.functions import GenericFunction

from sqlalchemy_hana.types import _RV, REAL_VECTOR


class cardinality(GenericFunction[int]):
type = Integer()
inherit_cache = True
_has_args = True


class cosine_similarity(GenericFunction[float]):
type = Float()
inherit_cache = True
_has_args = True


class l2distance(GenericFunction[float]):
type = Float()
inherit_cache = True
_has_args = True


class to_real_vector(GenericFunction[_RV], Generic[_RV]):
type = REAL_VECTOR[_RV]()
inherit_cache = True
_has_args = True


__all__ = ("cardinality", "cosine_similarity", "l2distance", "to_real_vector")
37 changes: 37 additions & 0 deletions test/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""SQL function tests."""

from __future__ import annotations

from sqlalchemy import select
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import TestBase

from sqlalchemy_hana.functions import (
cardinality,
cosine_similarity,
l2distance,
to_real_vector,
)


class FunctionTest(TestBase):

def test_cardinality(self, connection):
res = connection.execute(select(cardinality(to_real_vector("[1, 2, 3]")))).one()
eq_(res, (3,))

def test_cosine_similarity(self, connection):
res = connection.execute(
select(
cosine_similarity(
to_real_vector("[1, 0, 0]"), to_real_vector("[0.5, 0.8660254, 0]")
)
)
).one()
eq_(res, (0.5,))

def test_l2distance(self, connection):
res = connection.execute(
select(l2distance(to_real_vector("[2, 3, 5]"), to_real_vector("[6, 6, 5]")))
).one()
eq_(res, (5,))

0 comments on commit fb63b9c

Please sign in to comment.