diff --git a/tenant_schemas/management/commands/migrate_schemas.py b/tenant_schemas/management/commands/migrate_schemas.py index 3dc9d87b..346fd627 100644 --- a/tenant_schemas/management/commands/migrate_schemas.py +++ b/tenant_schemas/management/commands/migrate_schemas.py @@ -1,5 +1,10 @@ +import math +import django +import argparse + from django.core.management.commands.migrate import Command as MigrateCommand from django.db.migrations.exceptions import MigrationSchemaMissing + from tenant_schemas.management.commands import SyncCommon from tenant_schemas.migration_executors import get_executor from tenant_schemas.utils import ( @@ -9,6 +14,24 @@ ) +def chunks(tenants, total_parts): + """ + Iterates over tenants, returning each part, one at a time + """ + tenants_per_chunk = int(math.ceil(float(len(tenants)) / total_parts)) + for i in range(0, len(tenants), tenants_per_chunk): + yield tenants[i:i + tenants_per_chunk] + + +def greater_than_zero(astring): + if not astring.isdigit(): + raise argparse.ArgumentTypeError('Needs to be a number') + number = int(astring) + if not number > 0: + raise argparse.ArgumentTypeError('The number needs to be greated than zero') + return number + + class Command(SyncCommon): help = ( "Updates database schema. Manages both apps with migrations and those without." @@ -18,9 +41,25 @@ def add_arguments(self, parser): super(Command, self).add_arguments(parser) command = MigrateCommand() command.add_arguments(parser) + parser.add_argument('--part', action='store', dest='part', type=greater_than_zero, default=None, + help=('The part you want to process from the pieces (requires --of). ' + 'Example: --part 2 --of 3')) + parser.add_argument('--of', action='store', dest='total_parts', type=greater_than_zero, default=None, + help=('Splits the tenant schemas into specified number of pieces of equal size to ' + 'then be proccessed in parts (requires --part). Example: --part 2 --of 3')) def handle(self, *args, **options): super(Command, self).handle(*args, **options) + + required_together = (self.options['total_parts'], self.options['part'],) + if any(required_together) and not all(required_together): + raise Exception("--part and --of need to be used together.") + elif all(required_together): + if self.options['part'] > self.options['total_parts']: + raise Exception("--of cannot be greater than --part.") + elif self.sync_public: + raise Exception("Cannot run public schema migrations along with --of and --part.") + self.PUBLIC_SCHEMA_NAME = get_public_schema_name() executor = get_executor(codename=self.executor)(self.args, self.options) @@ -42,6 +81,15 @@ def handle(self, *args, **options): tenants = ( get_tenant_model() .objects.exclude(schema_name=get_public_schema_name()) - .values_list("schema_name", flat=True) + .values_list('schema_name', flat=True).order_by('pk') ) + if self.options['total_parts'] and tenants: + tenant_parts = list(chunks(tenants, self.options['total_parts'])) + try: + tenants = tenant_parts[self.options['part'] - 1] + except IndexError: + message = 'You have fewer tenants than parts. This part (%s) has nothing to do.\n' + self.stdout.write(message % self.options['part']) + return + executor.run_migrations(tenants=tenants) diff --git a/tenant_schemas/tests/test_tenants.py b/tenant_schemas/tests/test_tenants.py index 42aa9e73..a8246aa3 100644 --- a/tenant_schemas/tests/test_tenants.py +++ b/tenant_schemas/tests/test_tenants.py @@ -1,7 +1,11 @@ +from mock import patch + from django.conf import settings from django.contrib.auth.models import User from django.db import connection from dts_test_app.models import DummyModel, ModelWithFkToPublicUser + +from tenant_schemas.management.commands.migrate_schemas import greater_than_zero from tenant_schemas.management.commands import tenant_command from tenant_schemas.test.cases import TenantTestCase from tenant_schemas.tests.models import NonAutoSyncTenant, Tenant @@ -411,3 +415,116 @@ def test_tenant_survives_after_method1(self): def test_tenant_survives_after_method2(self): # The same tenant still exists even after the previous method call self.assertEquals(1, get_tenant_model().objects.all().count()) + + +class MigrateSchemasTest(BaseTestCase, TenantTestCase): + + def test_simple_options_command(self): + get_tenant_model_mock = patch('tenant_schemas.management.commands.migrate_schemas.get_tenant_model') + get_executor_mock = patch('tenant_schemas.management.commands.migrate_schemas.get_executor') + + with get_tenant_model_mock as get_tenant_model, get_executor_mock as get_executor: + run_migrations = get_executor.return_value.return_value.run_migrations + query_obj = get_tenant_model.return_value.objects.exclude.return_value.order_by.return_value.values_list + query_obj.return_value = ['a', 'b', 'c'] + + call_command('migrate_schemas', + tenant=True, + part=1, + of=3) + + query_obj.assert_called_once() + run_migrations.assert_called_once_with(tenants=['a']) + + call_command('migrate_schemas', + tenant=True, + part=2, + of=3) + + run_migrations.assert_called_with(tenants=['b']) + + call_command('migrate_schemas', + tenant=True, + part=3, + of=3) + + run_migrations.assert_called_with(tenants=['c']) + + with self.assertRaises(Exception) as context: + call_command('migrate_schemas', + tenant=True, + part=4, + of=3) + + self.assertIn('--of cannot be greater than --part.', str(context.exception)) + + def test_rounding_command(self): + get_tenant_model_mock = patch('tenant_schemas.management.commands.migrate_schemas.get_tenant_model') + get_executor_mock = patch('tenant_schemas.management.commands.migrate_schemas.get_executor') + + with get_tenant_model_mock as get_tenant_model, get_executor_mock as get_executor: + run_migrations = get_executor.return_value.return_value.run_migrations + query_obj = get_tenant_model.return_value.objects.exclude.return_value.order_by.return_value.values_list + query_obj.return_value = ['a'] + + call_command('migrate_schemas', + tenant=True, + part=1, + of=2) + + run_migrations.assert_called_with(tenants=['a']) + + out = StringIO() + call_command('migrate_schemas', + tenant=True, + part=2, + of=2, + stdout=out) + self.assertIn('You have fewer tenants than parts', out.getvalue()) + + def test_rounding2_command(self): + get_tenant_model_mock = patch('tenant_schemas.management.commands.migrate_schemas.get_tenant_model') + get_executor_mock = patch('tenant_schemas.management.commands.migrate_schemas.get_executor') + + with get_tenant_model_mock as get_tenant_model, get_executor_mock as get_executor: + run_migrations = get_executor.return_value.return_value.run_migrations + query_obj = get_tenant_model.return_value.objects.exclude.return_value.order_by.return_value.values_list + query_obj.return_value = list(range(53)) + + call_command('migrate_schemas', + tenant=True, + part=1, + of=7) + + run_migrations.assert_called_with(tenants=[0, 1, 2, 3, 4, 5, 6, 7]) + + call_command('migrate_schemas', + tenant=True, + part=7, + of=7) + + run_migrations.assert_called_with(tenants=[48, 49, 50, 51, 52]) + + def test_errors_command(self): + + with self.assertRaises(Exception) as context: + call_command('migrate_schemas', + part=1, + of=1) + self.assertIn('Cannot run public schema migrations along with --of and --part.', str(context.exception)) + + with self.assertRaises(Exception) as context: + call_command('migrate_schemas', + of=1) + self.assertIn('need to be used together', str(context.exception)) + + # Note: Before django 2.0, django bypasses the the "type=" checker in the argparse completely, thats why + # we check the underlining function directly. + greater_than_zero('1') + with self.assertRaises(argparse.ArgumentTypeError) as context: + greater_than_zero('0') + self.assertIn('The number needs to be greated than zero', str(context.exception)) + + with self.assertRaises(argparse.ArgumentTypeError) as context: + greater_than_zero('0a') + self.assertIn('Needs to be a number', str(context.exception))