Skip to content

Commit

Permalink
Add support for psycopg and asyncpg drivers
Browse files Browse the repository at this point in the history
This introduces the `crate+psycopg://`, `crate+asyncpg://`, and
`crate+urllib3://` dialect identifiers. The asynchronous variant of
`psycopg` is also supported.
  • Loading branch information
amotl committed Jun 25, 2024
1 parent 6db4702 commit e8bfd77
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@


## Unreleased
- Added support for `psycopg` and `asyncpg` drivers, by introducing the
`crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect
identifiers. The asynchronous variant of `psycopg` is also supported.

## 2024/06/25 0.38.0
- Added/reactivated documentation as `sqlalchemy-cratedb`
Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ dependencies = [
]
[project.optional-dependencies]
all = [
"sqlalchemy-cratedb[vector]",
"sqlalchemy-cratedb[postgresql,vector]",
]
develop = [
"black<25",
Expand All @@ -107,6 +107,9 @@ doc = [
"crate-docs-theme>=0.26.5",
"sphinx<8,>=3.5",
]
postgresql = [
"sqlalchemy-postgresql-relaxed",
]
release = [
"build<2",
"twine<6",
Expand All @@ -117,6 +120,7 @@ test = [
"pandas<2.3",
"pueblo>=0.0.7",
"pytest<9",
"pytest-asyncio<0.24",
"pytest-cov<6",
"pytest-mock<4",
]
Expand All @@ -129,7 +133,11 @@ documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/"
homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/"
repository = "https://github.com/crate/sqlalchemy-cratedb"
[project.entry-points."sqlalchemy.dialects"]
crate = "sqlalchemy_cratedb:dialect"
"crate" = "sqlalchemy_cratedb:dialect"
"crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3"
"crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg"
"crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async"
"crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg"

[tool.black]
line-length = 100
Expand Down
42 changes: 37 additions & 5 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import logging
from datetime import datetime, date
from types import ModuleType

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
Expand Down Expand Up @@ -202,6 +203,12 @@ def initialize(self, connection):
self.default_schema_name = \
self._get_default_schema_name(connection)

def set_isolation_level(self, dbapi_connection, level):
"""
For CrateDB, this is implemented as a noop.
"""
pass

def do_rollback(self, connection):
# if any exception is raised by the dbapi, sqlalchemy by default
# attempts to do a rollback crate doesn't support rollbacks.
Expand All @@ -220,7 +227,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
use_ssl = asbool(kwargs.pop("ssl", False))
if use_ssl:
servers = ["https://" + server for server in servers]
return self.dbapi.connect(servers=servers, **kwargs)

is_module = isinstance(self.dbapi, ModuleType)
if is_module:
driver_name = self.dbapi.__name__
else:
driver_name = self.dbapi.__class__.__name__
if driver_name == "crate.client":
if "database" in kwargs:
del kwargs["database"]

Check warning on line 238 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L238

Added line #L238 was not covered by tests
return self.dbapi.connect(servers=servers, **kwargs)
elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]:
return self.dbapi.connect(host=host, port=port, **kwargs)
else:
raise ValueError(f"Unknown driver variant: {driver_name}")

Check warning on line 243 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L243

Added line #L243 was not covered by tests

return self.dbapi.connect(**kwargs)

def _get_default_schema_name(self, connection):
Expand Down Expand Up @@ -266,11 +287,11 @@ def get_schema_names(self, connection, **kw):
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self._get_effective_schema_name(connection)
cursor = connection.exec_driver_sql(
cursor = connection.exec_driver_sql(self._format_query(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
"AND table_type = 'BASE TABLE' "
"ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
"ORDER BY table_name ASC, {0} ASC").format(self.schema_column),
(schema or self.default_schema_name, )
)
return [row[0] for row in cursor.fetchall()]
Expand All @@ -292,7 +313,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"AND column_name !~ ?" \
.format(self.schema_column)
cursor = connection.exec_driver_sql(
query,
self._format_query(query),
(table_name,
schema or self.default_schema_name,
r"(.*)\[\'(.*)\'\]") # regex to filter subscript
Expand Down Expand Up @@ -331,7 +352,7 @@ def result_fun(result):
return set(rows[0] if rows else [])

pk_result = engine.exec_driver_sql(
query,
self._format_query(query),
(table_name, schema or self.default_schema_name)
)
pks = result_fun(pk_result)
Expand Down Expand Up @@ -372,6 +393,17 @@ def has_ilike_operator(self):
server_version_info = self.server_version_info
return server_version_info is not None and server_version_info >= (4, 1, 0)

def _format_query(self, query):
"""
When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
the paramstyle is not `qmark`, but `pyformat`.
TODO: Review: Is it legit and sane? Are there alternatives?
"""
if self.paramstyle == "pyformat":
query = query.replace("= ?", "= %s").replace("!~ ?", "!~ %s")

Check warning on line 404 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L404

Added line #L404 was not covered by tests
return query


class DateTrunc(functions.GenericFunction):
name = "date_trunc"
Expand Down
106 changes: 106 additions & 0 deletions src/sqlalchemy_cratedb/dialect_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed
from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed
from sqlalchemy_postgresql_relaxed.psycopg import (
PGDialect_psycopg_relaxed,
PGDialectAsync_psycopg_relaxed,
)

from sqlalchemy_cratedb import dialect


class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect):
"""
Provide a dialect on top of the relaxed PostgreSQL dialect.
"""

inspector = Inspector

# Need to manually override some methods because of polymorphic inheritance woes.
# TODO: Investigate if this can be solved using metaprogramming or other techniques.
has_schema = dialect.has_schema
has_table = dialect.has_table
get_schema_names = dialect.get_schema_names
get_table_names = dialect.get_table_names
get_view_names = dialect.get_view_names
get_columns = dialect.get_columns
get_pk_constraint = dialect.get_pk_constraint
get_foreign_keys = dialect.get_foreign_keys
get_indexes = dialect.get_indexes

get_multi_columns = dialect.get_multi_columns
get_multi_pk_constraint = dialect.get_multi_pk_constraint
get_multi_foreign_keys = dialect.get_multi_foreign_keys

# TODO: Those may want to go to dialect instead?
def get_multi_indexes(self, *args, **kwargs):
return []

Check warning on line 57 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L57

Added line #L57 was not covered by tests

def get_multi_unique_constraints(self, *args, **kwargs):
return []

Check warning on line 60 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L60

Added line #L60 was not covered by tests

def get_multi_check_constraints(self, *args, **kwargs):
return []

Check warning on line 63 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L63

Added line #L63 was not covered by tests

def get_multi_table_comment(self, *args, **kwargs):
return []

Check warning on line 66 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L66

Added line #L66 was not covered by tests


class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter):
driver = "psycopg"

@classmethod
def get_async_dialect_cls(cls, url):
return CrateDialectAsync_psycopg

@classmethod
def import_dbapi(cls):
import psycopg

return psycopg


class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter):
driver = "psycopg_async"
is_async = True


class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter):
driver = "asyncpg"

# TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this!

# TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle'
"""
@classmethod
def import_dbapi(cls):
import asyncpg
return asyncpg
"""


dialect_urllib3 = dialect
dialect_psycopg = CrateDialect_psycopg
dialect_psycopg_async = CrateDialectAsync_psycopg
dialect_asyncpg = CrateDialect_asyncpg
81 changes: 81 additions & 0 deletions tests/engine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.dialects import registry as dialect_registry

from sqlalchemy_cratedb import SA_VERSION, SA_2_0

if SA_VERSION < SA_2_0:
raise pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True)

from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine

# Registering the additional dialects manually seems to be needed when running
# under tests. Apparently, manual registration is not needed under regular
# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint
# registrations in `pyproject.toml`. It is definitively weird, but c'est la vie.
dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3")
dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg")
dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg")


QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;")


def test_engine_sync_vanilla():
"""
crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver.
"""
engine = sa.create_engine("crate://crate@localhost:4200/", echo=True)
assert isinstance(engine, sa.engine.Engine)
with engine.connect() as connection:
result = connection.execute(QUERY)
assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': [10.95667, 47.18917]}


def test_engine_sync_urllib3():
"""
crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver.
"""
engine = sa.create_engine("crate+urllib3://crate@localhost:4200/", isolation_level="AUTOCOMMIT", echo=True)
assert isinstance(engine, sa.engine.Engine)
with engine.connect() as connection:
result = connection.execute(QUERY)
assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': [10.95667, 47.18917]}


def test_engine_sync_psycopg():
"""
crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
"""
engine = sa.create_engine("crate+psycopg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True)
assert isinstance(engine, sa.engine.Engine)
with engine.connect() as connection:
result = connection.execute(QUERY)
assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': '(10.95667,47.18917)'}


@pytest.mark.asyncio
async def test_engine_async_psycopg():
"""
crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
This time, in asynchronous mode.
"""
engine = create_async_engine("crate+psycopg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True)
assert isinstance(engine, AsyncEngine)
async with engine.begin() as conn:
result = await conn.execute(QUERY)
assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': '(10.95667,47.18917)'}


@pytest.mark.asyncio
async def test_engine_async_asyncpg():
"""
crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver.
This exclusively uses asynchronous mode.
"""
from asyncpg.pgproto.types import Point
engine = create_async_engine("crate+asyncpg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True)
assert isinstance(engine, AsyncEngine)
async with engine.begin() as conn:
result = await conn.execute(QUERY)
assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': Point(10.95667, 47.18917)}

0 comments on commit e8bfd77

Please sign in to comment.