Skip to content

Commit 8d8b2c8

Browse files
committed
support db_default
1 parent 2c7586a commit 8d8b2c8

File tree

3 files changed

+91
-6
lines changed

3 files changed

+91
-6
lines changed

django_singlestore/features.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
335335
supports_unlimited_charfield = False
336336

337337
# Does the backend support database-side default values?
338-
supports_db_default = False
338+
supports_db_default = True
339339

340340
supports_update_conflicts_with_target = False
341341

@@ -878,6 +878,9 @@ def django_test_expected_failures(self):
878878
# Auto increment fields must have BIGINT data type . default is BigAutoField
879879
"introspection.tests.IntrospectionTests.test_get_table_description_types",
880880
"introspection.tests.IntrospectionTests.test_smallautofield",
881+
# db_default parameter does no support complex functions.
882+
"field_defaults.tests.DefaultTests.test_case_when_db_default_no_returning",
883+
"migrations.test_operations.OperationTests.test_add_field_database_default_function",
881884
}
882885

883886
return fails

django_singlestore/schema.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
44
from django.db.models import Manager
55
from django.db.models import NOT_PROVIDED
6+
from django.db.models.sql import Query
67

78

89
class ModelStorageManager(Manager):
@@ -67,6 +68,25 @@ def quote_value(self, value):
6768
def prepare_default(self, value):
6869
return self.quote_value(value)
6970

71+
def db_default_sql(self, field):
72+
"""Return the sql and params for the field's database default."""
73+
from django.db.models.expressions import Value
74+
75+
db_default = field._db_default_expression
76+
sql = (
77+
self._column_default_sql(field) if isinstance(db_default, Value) else "%s"
78+
)
79+
query = Query(model=field.model)
80+
compiler = query.get_compiler(connection=self.connection)
81+
default_sql, params = compiler.compile(db_default)
82+
if self.connection.features.requires_literal_defaults:
83+
# Some databases don't support parameterized defaults (Oracle,
84+
# SQLite). If this is the case, the individual schema backend
85+
# should implement prepare_default().
86+
default_sql %= tuple(self.prepare_default(p) for p in params)
87+
params = []
88+
return sql % default_sql, params
89+
7090
def column_sql(self, model, field, include_default=False):
7191
"""
7292
Return the column definition for a field. The field must already have
@@ -85,9 +105,30 @@ def column_sql(self, model, field, include_default=False):
85105
result_sql_parts.append(self._collate_sql(collation))
86106
if self.connection.features.supports_comments_inline and field.db_comment:
87107
result_sql_parts.append(self._comment_sql(field.db_comment))
88-
# Include a default value, if requested.
89-
include_default = include_default and not field.null
90-
if include_default:
108+
109+
# Handle db_default (database-level default)
110+
if hasattr(field, 'db_default') and field.db_default is not NOT_PROVIDED:
111+
# db_default takes precedence over regular default when creating columns
112+
default_sql, default_params = self.db_default_sql(field)
113+
result_sql_parts.append("DEFAULT " + default_sql)
114+
params.extend(default_params)
115+
# if hasattr(field.db_default, 'resolve_expression'):
116+
# # It's a database function (like Now())
117+
# compiler = self.connection.ops.compiler('SQLCompiler')(
118+
# query=None, connection=self.connection, using=None
119+
# )
120+
# db_default_sql = field.db_default.as_sql(compiler, self.connection)[0]
121+
# result_sql_parts.append("DEFAULT " + db_default_sql)
122+
# else:
123+
# # It's a literal value
124+
# db_default_sql = self._column_default_sql(field)
125+
# if self.connection.features.requires_literal_defaults:
126+
# result_sql_parts.append("DEFAULT " + db_default_sql % self.prepare_default(field.db_default))
127+
# else:
128+
# result_sql_parts.append("DEFAULT " + db_default_sql)
129+
# params.append(field.db_default)
130+
# Include a regular default value, if requested and no db_default is set.
131+
elif include_default and not field.null:
91132
default_value = self.effective_default(field)
92133
if default_value is not None:
93134
column_default = "DEFAULT " + self._column_default_sql(field)
@@ -221,8 +262,19 @@ def add_field(self, model, field):
221262
super().add_field(model, field)
222263

223264
# Simulate the effect of a one-off default.
265+
if hasattr(field, 'db_default') and field.db_default is not NOT_PROVIDED:
266+
default_sql, default_params = self.db_default_sql(field)
267+
self.execute(
268+
"UPDATE %(table)s SET %(column)s = %(default_value)s"
269+
% {
270+
"table": self.quote_name(model._meta.db_table),
271+
"column": self.quote_name(field.column),
272+
"default_value": default_sql,
273+
},
274+
default_params,
275+
)
224276
# field.default may be unhashable, so a set isn't used for "in" check.
225-
if field.default not in (None, NOT_PROVIDED):
277+
elif field.default not in (None, NOT_PROVIDED):
226278
effective_default = self.effective_default(field)
227279
self.execute(
228280
"UPDATE %(table)s SET %(column)s = %%s"
@@ -271,3 +323,33 @@ def _alter_column_type_sql(
271323
return super()._alter_column_type_sql(
272324
model, old_field, new_field, new_type, old_collation, new_collation,
273325
)
326+
327+
def _alter_column_null_sql(self, model, old_field, new_field):
328+
if new_field.db_default is NOT_PROVIDED:
329+
return super()._alter_column_null_sql(model, old_field, new_field)
330+
331+
new_db_params = new_field.db_parameters(connection=self.connection)
332+
type_sql = self._set_field_new_type(new_field, new_db_params["type"])
333+
return (
334+
"MODIFY %(column)s %(type)s"
335+
% {
336+
"column": self.quote_name(new_field.column),
337+
"type": type_sql,
338+
},
339+
[],
340+
)
341+
342+
def _set_field_new_type(self, field, new_type):
343+
"""
344+
Keep the NULL and DEFAULT properties of the old field. If it has
345+
changed, it will be handled separately.
346+
"""
347+
if field.db_default is not NOT_PROVIDED:
348+
default_sql, params = self.db_default_sql(field)
349+
default_sql %= tuple(self.quote_value(p) for p in params)
350+
new_type += f" DEFAULT {default_sql}"
351+
if field.null:
352+
new_type += " NULL"
353+
else:
354+
new_type += " NOT NULL"
355+
return new_type

scripts/get_test_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def get_test_modules(tests_root):
99
subdirs_to_skip = {
1010
"import_error_package", "test_runner_apps", "__pycache__", "requirements",
11-
"gis_tests", "postgres_tests", "distinct_on_fields", "templates", "field_defaults",
11+
"gis_tests", "postgres_tests", "distinct_on_fields", "templates",
1212
}
1313
test_modules = []
1414
for item in os.listdir(tests_root):

0 commit comments

Comments
 (0)