Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added get_sequences so command flush works #567

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"Topic :: Database",
"Topic :: Software Development :: Libraries",
],
install_requires=["Django>=1.11", "ordered-set", "psycopg2-binary", "six"],
install_requires=["Django>=1.11", "ordered-set", "psycopg-binary", "six"],
setup_requires=["setuptools-scm"],
use_scm_version=True,
zip_safe=False,
Expand Down
844 changes: 844 additions & 0 deletions tenant_schemas/clone.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions tenant_schemas/management/commands/migrate_schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.core.management.commands.migrate import Command as MigrateCommand
from django.db.migrations.exceptions import MigrationSchemaMissing
import django
from tenant_schemas.management.commands import SyncCommon
from tenant_schemas.migration_executors import get_executor
from tenant_schemas.utils import (
Expand All @@ -10,6 +11,9 @@


class Command(SyncCommon):
if django.VERSION >= (3, 1):
# https://github.com/bernardopires/django-tenant-schemas/issues/648#issuecomment-671115840
requires_system_checks = []
help = (
"Updates database schema. Manages both apps with migrations and those without."
)
Expand Down Expand Up @@ -41,7 +45,7 @@ def handle(self, *args, **options):
else:
tenants = (
get_tenant_model()
.objects.exclude(schema_name=get_public_schema_name())
.values_list("schema_name", flat=True)
.objects.exclude(schema_name=get_public_schema_name())
.values_list("schema_name", flat=True)
)
executor.run_migrations(tenants=tenants)
25 changes: 15 additions & 10 deletions tenant_schemas/postgresql_backend/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import re
import warnings
import psycopg2

from django.conf import settings
from django.contrib.contenttypes.models import ContentType
Expand All @@ -10,6 +9,15 @@
from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls
from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection

try:
from django.db.backends.postgresql.psycopg_any import is_psycopg3
except ImportError:
is_psycopg3 = False

if is_psycopg3:
import psycopg
else:
import psycopg2 as psycopg

ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql_psycopg2')
# Django 1.9+ takes care to rename the default backend to 'django.db.backends.postgresql'
Expand Down Expand Up @@ -131,7 +139,6 @@ def _cursor(self, name=None):
"to call set_schema() or set_tenant()?")
_check_schema_name(self.schema_name)
public_schema_name = get_public_schema_name()
search_paths = []

if self.schema_name == public_schema_name:
search_paths = [public_schema_name]
Expand All @@ -142,25 +149,22 @@ def _cursor(self, name=None):

search_paths.extend(EXTRA_SEARCH_PATHS)

if name:
# Named cursor can only be used once
cursor_for_search_path = self.connection.cursor()
else:
# Reuse
cursor_for_search_path = cursor
# Named cursor can only be used once, just like psycopg3 cursors.
needs_new_cursor = name or is_psycopg3
cursor_for_search_path = self.connection.cursor() if needs_new_cursor else cursor

# In the event that an error already happened in this transaction and we are going
# to rollback we should just ignore database error when setting the search_path
# if the next instruction is not a rollback it will just fail also, so
# we do not have to worry that it's not the good one
try:
cursor_for_search_path.execute('SET search_path = {0}'.format(','.join(search_paths)))
except (django.db.utils.DatabaseError, psycopg2.InternalError):
except (django.db.utils.DatabaseError, psycopg.InternalError):
self.search_path_set = False
else:
self.search_path_set = True

if name:
if needs_new_cursor:
cursor_for_search_path.close()

return cursor
Expand All @@ -171,5 +175,6 @@ class FakeTenant:
We can't import any db model in a backend (apparently?), so this class is used
for wrapping schema names in a tenant-like structure.
"""

def __init__(self, schema_name):
self.schema_name = schema_name
39 changes: 36 additions & 3 deletions tenant_schemas/postgresql_backend/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,19 @@
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)

try:
# Django >= 1.11
from django.db.models.indexes import Index
except ImportError:
Index = None
from django.utils.encoding import force_text

try:
from django.utils.encoding import force_str
except ImportError:
# Django < 4.0
from django.utils.encoding import force_text as force_str


fields = FieldInfo._fields
if 'default' not in fields:
Expand Down Expand Up @@ -174,6 +181,21 @@ class DatabaseSchemaIntrospection(BaseDatabaseIntrospection):
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
"""

_get_sequences_query = """
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
JOIN pg_namespace n ON n.oid = tbl.relnamespace
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND n.nspname = %(schema)s
AND tbl.relname = %(table)s
"""

def get_field_type(self, data_type, description):
field_type = super(DatabaseSchemaIntrospection, self).get_field_type(data_type, description)
if description.default and 'nextval' in description.default:
Expand Down Expand Up @@ -213,9 +235,9 @@ def get_table_description(self, cursor, table_name):

return [
FieldInfo(*(
(force_text(line[0]),) +
(force_str(line[0]),) +
line[1:6] +
(field_map[force_text(line[0])][0] == 'YES', field_map[force_text(line[0])][1])
(field_map[force_str(line[0])][0] == 'YES', field_map[force_str(line[0])][1])
)) for line in cursor.description
]

Expand Down Expand Up @@ -315,3 +337,14 @@ def get_constraints(self, cursor, table_name):
"options": options,
}
return constraints

def get_sequences(self, cursor, table_name, table_fields=()):
sequences = []
cursor.execute(self._get_sequences_query, {
'schema': self.connection.schema_name,
'table': table_name,
})

for row in cursor.fetchall():
sequences.append({'name': row[0], 'table': table_name, 'column': row[1]})
return sequences
20 changes: 20 additions & 0 deletions tenant_schemas/rename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from django.core.exceptions import ValidationError
from django.db import connection
from tenant_schemas.postgresql_backend.base import _is_valid_schema_name
from tenant_schemas.utils import schema_exists


def rename_schema(*, schema_name, new_schema_name):
"""
This renames a schema to a new name. It checks to see if it exists first
"""
cursor = connection.cursor()

if schema_exists(new_schema_name):
raise ValidationError("New schema name already exists")
if not _is_valid_schema_name(new_schema_name):
raise ValidationError("Invalid string used for the schema name.")

sql = 'ALTER SCHEMA {0} RENAME TO {1}'.format(schema_name, new_schema_name)
cursor.execute(sql)
cursor.close()
6 changes: 5 additions & 1 deletion tenant_schemas/signals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from django.dispatch import Signal

post_schema_sync = Signal(providing_args=['tenant'])
try:
# Django < 4.0
post_schema_sync = Signal(providing_args=['tenant'])
except TypeError:
post_schema_sync = Signal()
post_schema_sync.__doc__ = """
Sent after a tenant has been saved, its schema created and synced
"""
12 changes: 8 additions & 4 deletions tenant_schemas/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class MissingDefaultTenantMiddleware(DefaultTenantMiddleware):
DEFAULT_SCHEMA_NAME = "missing"


def dummy_get_response(request): # pragma: no cover
return None


@unittest.skipIf(six.PY2, "Unexpectedly failing only on Python 2.7")
class RoutesTestCase(BaseTestCase):
@classmethod
Expand All @@ -40,8 +44,8 @@ def setUpClass(cls):
def setUp(self):
super(RoutesTestCase, self).setUp()
self.factory = RequestFactory()
self.tm = TenantMiddleware()
self.dtm = DefaultTenantMiddleware()
self.tm = TenantMiddleware(dummy_get_response)
self.dtm = DefaultTenantMiddleware(dummy_get_response)

self.tenant_domain = "tenant.test.com"
self.tenant = Tenant(domain_url=self.tenant_domain, schema_name="test")
Expand Down Expand Up @@ -84,7 +88,7 @@ def test_non_existent_tenant_to_default_schema_routing(self):

def test_non_existent_tenant_custom_middleware(self):
"""Route unrecognised hostnames to the 'test' tenant."""
dtm = TestDefaultTenantMiddleware()
dtm = TestDefaultTenantMiddleware(dummy_get_response)
request = self.factory.get(
self.url, HTTP_HOST=self.non_existent_tenant.domain_url
)
Expand All @@ -94,7 +98,7 @@ def test_non_existent_tenant_custom_middleware(self):

def test_non_existent_tenant_and_default_custom_middleware(self):
"""Route unrecognised hostnames to the 'missing' tenant."""
dtm = MissingDefaultTenantMiddleware()
dtm = MissingDefaultTenantMiddleware(dummy_get_response)
request = self.factory.get(
self.url, HTTP_HOST=self.non_existent_tenant.domain_url
)
Expand Down