Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/datajoint/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init_subclass__(cls, *, register: bool = True, **kwargs):
_codec_registry[cls.name] = cls()
logger.debug(f"Registered codec <{cls.name}> from {cls.__module__}.{cls.__name__}")

@abstractmethod
def get_dtype(self, is_external: bool) -> str:
"""
Return the storage dtype for this codec.
Expand All @@ -136,12 +137,10 @@ def get_dtype(self, is_external: bool) -> str:

Raises
------
NotImplementedError
If not overridden by subclass.
DataJointError
If external storage not supported but requested.
"""
raise NotImplementedError(f"Codec <{self.name}> must implement get_dtype()")
...

@abstractmethod
def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> Any:
Expand Down
21 changes: 17 additions & 4 deletions src/datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,10 +704,12 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields):
rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
if rows:
try:
query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
# Handle empty field_list (all-defaults insert)
fields_clause = f"(`{'`,`'.join(field_list)}`)" if field_list else "()"
query = "{command} INTO {destination}{fields} VALUES {placeholders}{duplicate}".format(
command="REPLACE" if replace else "INSERT",
destination=self.from_clause(),
fields="`,`".join(field_list),
fields=fields_clause,
placeholders=",".join("(" + ",".join(row["placeholders"]) + ")" for row in rows),
duplicate=(
" ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`".format(pk=self.primary_key[0]) if skip_duplicates else ""
Expand Down Expand Up @@ -1239,8 +1241,19 @@ def check_fields(fields):
if ignore_extra_fields:
attributes = [a for a in attributes if a is not None]

assert len(attributes), "Empty tuple"
row_to_insert = dict(zip(("names", "placeholders", "values"), zip(*attributes)))
if not attributes:
# Check if empty insert is allowed (all attributes have defaults)
required_attrs = [
attr.name
for attr in self.heading.attributes.values()
if not (attr.autoincrement or attr.nullable or attr.default is not None)
]
if required_attrs:
raise DataJointError(f"Cannot insert empty row. The following attributes require values: {required_attrs}")
# All attributes have defaults - allow empty insert
row_to_insert = {"names": (), "placeholders": (), "values": ()}
else:
row_to_insert = dict(zip(("names", "placeholders", "values"), zip(*attributes)))
if not field_list:
# first row sets the composition of the field list
field_list.extend(row_to_insert["names"])
Expand Down
69 changes: 69 additions & 0 deletions tests/integration/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,72 @@ def test_validation_result_summary_truncated(self, schema_insert):
result = dj.ValidationResult(is_valid=False, errors=errors, rows_checked=20)
summary = result.summary()
assert "and 10 more errors" in summary


class AllDefaultsTable(dj.Manual):
"""Table where all attributes have defaults."""

definition = """
id : int auto_increment
---
timestamp=CURRENT_TIMESTAMP : datetime
notes=null : varchar(200)
"""


class TestEmptyInsert:
"""Tests for inserting empty dicts (GitHub issue #1280)."""

@pytest.fixture
def schema_empty_insert(self, connection_test, prefix):
schema = dj.Schema(
prefix + "_empty_insert_test",
context=dict(AllDefaultsTable=AllDefaultsTable, SimpleTable=SimpleTable),
connection=connection_test,
)
schema(AllDefaultsTable)
schema(SimpleTable)
yield schema
schema.drop()

def test_empty_insert_all_defaults(self, schema_empty_insert):
"""Test that empty insert succeeds when all attributes have defaults."""
table = AllDefaultsTable()
assert len(table) == 0

# Insert empty dict - should use all defaults
table.insert1({})
assert len(table) == 1

# Check that values were populated with defaults
row = table.fetch1()
assert row["id"] == 1 # auto_increment starts at 1
assert row["timestamp"] is not None # CURRENT_TIMESTAMP
assert row["notes"] is None # nullable defaults to NULL

def test_empty_insert_multiple(self, schema_empty_insert):
"""Test inserting multiple empty dicts."""
table = AllDefaultsTable()

# Insert multiple empty dicts
table.insert([{}, {}, {}])
assert len(table) == 3

# Each should have unique auto_increment id
ids = set(table.to_arrays("id"))
assert ids == {1, 2, 3}

def test_empty_insert_required_fields_error(self, schema_empty_insert):
"""Test that empty insert raises clear error when fields are required."""
table = SimpleTable()

# SimpleTable has required fields (id, value)
with pytest.raises(dj.DataJointError) as exc_info:
table.insert1({})

error_msg = str(exc_info.value)
assert "Cannot insert empty row" in error_msg
assert "require values" in error_msg
# Should list the required attributes
assert "id" in error_msg
assert "value" in error_msg
Loading