Skip to content

Commit

Permalink
ENH: Check for column name collisions
Browse files Browse the repository at this point in the history
  • Loading branch information
genematx committed Mar 10, 2025
1 parent 9069911 commit 76560c2
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 16 deletions.
34 changes: 34 additions & 0 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,40 @@ async def read(self, *args, **kwargs):
return self
return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs)

async def create_node(
self,
structure_family,
metadata,
key=None,
specs=None,
data_sources=None,
):
key = key or self.context.key_maker()

# Check for column name collisions in a Composite node
if self.structure_family == StructureFamily.composite:
assert len(data_sources) == 1
if data_sources[0].structure_family == StructureFamily.table:
new_keys = data_sources[0].structure.columns
elif data_sources[0].structure_family in [StructureFamily.array, StructureFamily.awkward, StructureFamily.sparse]:
new_keys = [key]
else:
raise ValueError(f"Unsupported structure family: {data_sources[0].structure_family}")

# Get all keys and columns names in the Composite node
flat_keys = []
for _key, item in await self.items_range(offset=0, limit=None):
flat_keys.append(_key)
if item.structure_family == StructureFamily.table:
flat_keys.extend(item.structure().columns)

key_conflicts = set(new_keys) & set(flat_keys)
if key_conflicts:
raise Collision(f"Column name collision: {key_conflicts}")

return await super().create_node(
structure_family, metadata, key=key, specs=specs, data_sources=data_sources
)

class CatalogArrayAdapter(CatalogNodeAdapter):
async def read(self, *args, **kwargs):
Expand Down
82 changes: 66 additions & 16 deletions tiled/client/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,8 +1052,9 @@ def write_dataframe(


class Composite(Container):

@property
def _flat_keys_mapping(self, maxlen=None):
def _contents(self, maxlen=None):
result = {}
next_page_url = f"{self.item['links']['search']}"
while (next_page_url is not None) or (
Expand All @@ -1067,25 +1068,31 @@ def _flat_keys_mapping(self, maxlen=None):
**parse_qs(urlparse(next_page_url).query),
**self._queries_as_params,
"select_metadata": False,
"omit_links": True,
},
)
).json()
self._cached_len = (
content["meta"]["count"],
time.monotonic() + LENGTH_CACHE_TTL,
)
for item in content["data"]:
if item["attributes"]["structure_family"] == StructureFamily.table:
for col in item["attributes"]["structure"]["columns"]:
result[col] = item["id"] + "/" + col
else:
result[item["id"]] = item["id"]
result.update({item['id'] : item for item in content["data"]})

next_page_url = content["links"]["next"]

return result

@property
def _flat_keys_mapping(self):
result = {}
for key, item in self._contents.items():
if item["attributes"]["structure_family"] == StructureFamily.table:
for col in item["attributes"]["structure"]["columns"]:
result[col] = item["id"] + "/" + col
else:
result[item["id"]] = item["id"]

return result

@property
def parts(self):
return CompositeContents(self)

def _keys_slice(self, start, stop, direction, _ignore_inlined_contents=False):
yield from self._flat_keys_mapping.keys()

Expand All @@ -1096,11 +1103,54 @@ def _items_slice(self, start, stop, direction, _ignore_inlined_contents=False):
def __len__(self):
return len(self._flat_keys_mapping)

def __getitem__(self, keys, _ignore_inlined_contents=False):
if keys in self._flat_keys_mapping:
keys = self._flat_keys_mapping[keys]
def __getitem__(self, key: str, _ignore_inlined_contents=False):
if key in self._flat_keys_mapping:
key = self._flat_keys_mapping[key]
else:
raise KeyError(key)

return super().__getitem__(key, _ignore_inlined_contents)


class CompositeContents:
def __init__(self, node):
self._contents = node._contents
self.context = node.context
self.structure_clients = node.structure_clients
self._include_data_sources = node._include_data_sources

def __repr__(self):
return (
f"<{type(self).__name__} {{"
+ ", ".join(f"'{item}'" for item in self._contents)
+ "}>"
)

def __getitem__(self, key):
key, *tail = key.split("/")

if key not in self._contents:
raise KeyError(key)

client = client_for_item(
self.context,
self.structure_clients,
self._contents[key],
include_data_sources=self._include_data_sources,
)

if tail:
return client['/'.join(tail)]
else:
return client


def __iter__(self):
for key in self._contents:
yield key

return super().__getitem__(keys, _ignore_inlined_contents)
def __len__(self) -> int:
return len(self._contents)


def _queries_to_params(*queries):
Expand Down

0 comments on commit 76560c2

Please sign in to comment.