Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VNDLY-42402: Implement get_sequences in our fork #6

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tenant_schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
default_app_config = 'tenant_schemas.apps.TenantSchemaConfig'

__version__ = "v1.9.0-vndly-0.0.4"
__version__ = "v1.9.0-vndly-0.0.5"
23 changes: 14 additions & 9 deletions tenant_schemas/postgresql_backend/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import re
import warnings
import psycopg2

from django.conf import settings
from django.contrib.contenttypes.models import ContentType
Expand All @@ -10,6 +9,15 @@
from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls
from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection

try:
from django.db.backends.postgresql.psycopg_any import is_psycopg3
except ImportError:
is_psycopg3 = False

if is_psycopg3:
import psycopg
else:
import psycopg2 as psycopg

ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql_psycopg2')
# Django 1.9+ takes care to rename the default backend to 'django.db.backends.postgresql'
Expand Down Expand Up @@ -142,25 +150,22 @@ def _cursor(self, name=None):

search_paths.extend(EXTRA_SEARCH_PATHS)

if name:
# Named cursor can only be used once
cursor_for_search_path = self.connection.cursor()
else:
# Reuse
cursor_for_search_path = cursor
# Named cursor can only be used once, just like psycopg3 cursors.
needs_new_cursor = name or is_psycopg3
cursor_for_search_path = self.connection.cursor() if needs_new_cursor else cursor

# In the event that an error already happened in this transaction and we are going
# to rollback we should just ignore database error when setting the search_path
# if the next instruction is not a rollback it will just fail also, so
# we do not have to worry that it's not the good one
try:
cursor_for_search_path.execute('SET search_path = {0}'.format(','.join(search_paths)))
except (django.db.utils.DatabaseError, psycopg2.InternalError):
except (django.db.utils.DatabaseError, psycopg.InternalError):
self.search_path_set = False
else:
self.search_path_set = True

if name:
if needs_new_cursor:
cursor_for_search_path.close()

return cursor
Expand Down
26 changes: 26 additions & 0 deletions tenant_schemas/postgresql_backend/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,21 @@ class DatabaseSchemaIntrospection(BaseDatabaseIntrospection):
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
"""

_get_sequences_query = """
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
JOIN pg_namespace n ON n.oid = tbl.relnamespace
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND n.nspname = %(schema)s
AND tbl.relname = %(table)s
"""

def get_field_type(self, data_type, description):
field_type = super(DatabaseSchemaIntrospection, self).get_field_type(data_type, description)
if description.default and 'nextval' in description.default:
Expand Down Expand Up @@ -315,3 +330,14 @@ def get_constraints(self, cursor, table_name):
"options": options,
}
return constraints

def get_sequences(self, cursor, table_name, table_fields=()):
sequences = []
cursor.execute(self._get_sequences_query, {
'schema': self.connection.schema_name,
'table': table_name,
})

for row in cursor.fetchall():
sequences.append({'name': row[0], 'table': table_name, 'column': row[1]})
return sequences