Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paimon-python/pypaimon/ray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from pypaimon.ray.ray_paimon import read_paimon, write_paimon
from pypaimon.ray.bucket_join import bucket_join
from pypaimon.ray.data_evolution_merge_into import (
WhenMatched,
WhenNotMatched,
Expand All @@ -30,6 +31,7 @@
__all__ = [
"read_paimon",
"write_paimon",
"bucket_join",
"merge_into",
"WhenMatched",
"WhenNotMatched",
Expand Down
139 changes: 139 additions & 0 deletions paimon-python/pypaimon/ray/bucket_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Bucket-aligned join on Ray for two co-bucketed Paimon tables.

Same key -> same bucket on both sides, so each bucket is read and joined in its own
Ray task with no global shuffle -- the no-shuffle alternative to ``ray.data.join``.
"""

from typing import Any, Dict, List, Optional, Sequence, Union

# Public API of this module: only the join entry point is exported.
__all__ = ["bucket_join"]

# Join-key spec accepted by ``bucket_join``: a single column name or a sequence of names.
OnSpec = Union[str, Sequence[str]]


def _norm(on: OnSpec) -> List[str]:
return [on] if isinstance(on, str) else list(on)


def _bucketing(table):
count = table.options.bucket()
key = table.options.bucket_key()
return count, ([k.strip() for k in key.split(",")] if key else [])


# Per-process table cache so a worker reuses table metadata across the many buckets
# it runs, instead of reloading the catalog per bucket. Only used when reading given
# splits (snapshot-independent); planning always loads a fresh table.
# Keys are ``(table_id, tuple(sorted(catalog_options.items())))`` as built by
# ``_get_table``; values are the loaded table objects. Nothing in this module
# evicts entries.
_TABLE_CACHE: Dict = {}


def _get_table(table_id, catalog_options, use_cache):
    """Load ``table_id`` from the catalog described by ``catalog_options``.

    With ``use_cache`` false, a new catalog is created and the table is loaded
    on every call. Otherwise the table is memoised in ``_TABLE_CACHE`` under
    ``(table_id, sorted option items)`` and reused by later calls in the same
    process.
    """
    from pypaimon.catalog.catalog_factory import CatalogFactory

    def _load():
        # One catalog instance per load; nothing but the table is retained.
        return CatalogFactory.create(catalog_options).get_table(table_id)

    if not use_cache:
        return _load()

    cache_key = (table_id, tuple(sorted(catalog_options.items())))
    if cache_key in _TABLE_CACHE:
        return _TABLE_CACHE[cache_key]
    table = _load()
    _TABLE_CACHE[cache_key] = table
    return table


def _read_builder(table_id, catalog_options, projection, use_cache=False):
    """Create a read builder for ``table_id``, applying ``projection`` when one is given.

    ``use_cache`` is forwarded to ``_get_table`` and controls whether the
    per-process table cache is used.
    """
    table = _get_table(table_id, catalog_options, use_cache)
    builder = table.new_read_builder()
    if projection is None:
        return builder
    return builder.with_projection(projection)


def _plan_splits_by_bucket(table_id, catalog_options, projection):
    """Plan the table scan once and return a ``{bucket: [splits]}`` mapping.

    The read builder is created with ``use_cache`` left at its default
    (``False``), so planning runs against a freshly loaded table.
    """
    builder = _read_builder(table_id, catalog_options, projection)
    grouped = {}
    for split in builder.new_scan().plan().splits():
        bucket = split.bucket
        if bucket not in grouped:
            grouped[bucket] = []
        grouped[bucket].append(split)
    return grouped


def _read_splits(table_id, catalog_options, projection, splits):
    """Read the given ``splits`` of ``table_id`` into an Arrow table."""
    # Reading already-planned splits is snapshot-independent, which is why the
    # per-process cached table is used here.
    builder = _read_builder(table_id, catalog_options, projection, use_cache=True)
    reader = builder.new_read()
    return reader.to_arrow(splits)


def bucket_join(
    left: str,
    right: str,
    catalog_options: Dict[str, str],
    *,
    on: OnSpec,
    left_projection: Optional[List[str]] = None,
    right_projection: Optional[List[str]] = None,
    join_type: str = "inner",
    ray_remote_args: Optional[Dict[str, Any]] = None,
) -> "ray.data.Dataset":
    """Join two co-bucketed tables bucket-by-bucket, with no global shuffle.

    Both tables must be fixed-bucket with the same bucket count and the same
    bucket-key, and ``on`` must equal that bucket-key. Each bucket present on
    both sides is read and joined in its own Ray task. The two sides must not
    share column names other than the join key (pyarrow ``join`` would
    otherwise collide).

    Args:
        left: Identifier of the left table.
        right: Identifier of the right table.
        catalog_options: Options used to create the catalog (driver and tasks).
        on: Join column name, or sequence of names; must equal the bucket-key.
        left_projection: Columns to read from the left table; ``None`` reads
            the builder's default. When given, it must contain every join column.
        right_projection: Same as ``left_projection`` for the right table.
        join_type: Only ``"inner"`` is supported.
        ray_remote_args: Optional keyword arguments passed to ``ray.remote``
            for the per-bucket task.

    Returns:
        A ``ray.data.Dataset`` built from the per-bucket join results.

    Raises:
        ValueError: If ``join_type`` is not ``"inner"``, ``on`` is empty, a
            projection omits a join column, either table is not fixed-bucket
            or has no bucket-key, the bucket counts or bucket-keys differ, or
            ``on`` is not the bucket-key.
    """
    # Cheap argument checks come first so bad calls fail before any catalog I/O.
    if join_type != "inner":
        # Outer joins would need the union of buckets (a bucket missing on one side
        # still emits rows); only inner is correct with the per-bucket intersection.
        raise ValueError(f"bucket_join currently supports only join_type='inner'; got {join_type!r}.")

    on_cols = _norm(on)
    if not on_cols:
        # Without this, an empty ``on`` would compare equal to an unset
        # bucket-key (``[] == []``) and slip through to the join itself.
        raise ValueError("bucket_join requires at least one join column in `on`.")

    for label, projection in (
        ("left_projection", left_projection),
        ("right_projection", right_projection),
    ):
        if projection is None:
            continue
        missing = [col for col in on_cols if col not in projection]
        if missing:
            # Fail on the driver instead of inside every remote join task.
            raise ValueError(
                f"bucket_join requires {label} to include the join key {on_cols}; "
                f"missing {missing}.")

    import ray
    from pypaimon.catalog.catalog_factory import CatalogFactory

    cat = CatalogFactory.create(catalog_options)
    lcount, lkey = _bucketing(cat.get_table(left))
    rcount, rkey = _bucketing(cat.get_table(right))

    if not lcount or lcount <= 0 or not rcount or rcount <= 0:
        raise ValueError(
            "bucket_join requires both tables to be fixed-bucket (bucket > 0); "
            f"got {left}={lcount}, {right}={rcount}.")
    if lcount != rcount:
        raise ValueError(
            f"bucket_join requires the same bucket count; {left}={lcount}, {right}={rcount}.")
    if not lkey or not rkey:
        # ``_bucketing`` yields [] when the bucket-key option is unset or empty.
        raise ValueError(
            "bucket_join requires an explicit bucket-key on both tables; "
            f"got {left}={lkey}, {right}={rkey}.")
    if lkey != rkey:
        raise ValueError(
            f"bucket_join requires the same bucket-key; {left}={lkey}, {right}={rkey}.")
    if on_cols != lkey:
        raise ValueError(
            f"bucket_join requires the join key to be the bucket-key {lkey}; got on={on_cols}. "
            "Equal keys only co-locate by bucket when joining on the bucket-key.")

    # Plan each side's manifest once (driver-side, split metadata only -- the join
    # results stay distributed below), then dispatch per-bucket splits to the tasks.
    left_by_bucket = _plan_splits_by_bucket(left, catalog_options, left_projection)
    right_by_bucket = _plan_splits_by_bucket(right, catalog_options, right_projection)

    def _join_bucket(left_splits, right_splits):
        left_t = _read_splits(left, catalog_options, left_projection, left_splits)
        right_t = _read_splits(right, catalog_options, right_projection, right_splits)
        return left_t.join(right_t, keys=on_cols, join_type=join_type)

    # ``@ray.remote()`` (empty parens) is rejected by Ray, so wrap conditionally.
    if ray_remote_args:
        remote_fn = ray.remote(**ray_remote_args)(_join_bucket)
    else:
        remote_fn = ray.remote(_join_bucket)

    # Inner join: only buckets present on both sides can match.
    buckets = sorted(set(left_by_bucket) & set(right_by_bucket))
    if not buckets:
        # No shared bucket: empty result, but keep the join schema (join two empties).
        empty = _read_splits(left, catalog_options, left_projection, []).join(
            _read_splits(right, catalog_options, right_projection, []),
            keys=on_cols, join_type=join_type)
        return ray.data.from_arrow(empty)

    # Keep each bucket's result as a distributed object ref -- never pulled into the driver.
    refs = [remote_fn.remote(left_by_bucket[b], right_by_bucket[b]) for b in buckets]
    return ray.data.from_arrow_refs(refs)
158 changes: 158 additions & 0 deletions paimon-python/pypaimon/tests/ray_bucket_join_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import shutil
import tempfile
import unittest

import pyarrow as pa
import pytest

pypaimon = pytest.importorskip("pypaimon")
ray = pytest.importorskip("ray")

from pypaimon import CatalogFactory, Schema
from pypaimon.ray import bucket_join


class RayBucketJoinTest(unittest.TestCase):
    """Bucket-aligned join between two HASH_FIXED tables must equal a global join,
    with each bucket joined only against the same bucket (no cross-bucket shuffle)."""

    NUM_BUCKETS = 8

    # True only when setUpClass started the Ray runtime itself; tearDownClass
    # then leaves a runtime owned by someone else untouched.
    _started_ray = False

    @classmethod
    def setUpClass(cls):
        """Create a temp-dir warehouse with a ``default`` database and ensure Ray is running."""
        cls.tempdir = tempfile.mkdtemp()
        cls.catalog_options = {"warehouse": os.path.join(cls.tempdir, "wh")}
        cls.catalog = CatalogFactory.create(cls.catalog_options)
        cls.catalog.create_database("default", True)
        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True, num_cpus=4)
            cls._started_ray = True

    @classmethod
    def tearDownClass(cls):
        """Shut down Ray if this class started it, then remove the temp warehouse."""
        try:
            if cls._started_ray and ray.is_initialized():
                ray.shutdown()
        except Exception:
            # Deliberately best-effort: a failed shutdown must not mask test
            # results or prevent the temp directory from being removed.
            pass
        finally:
            shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_bucketed(self, name, schema, key, num_buckets):
        """Create an empty table ``name`` with the given bucket count and bucket-key."""
        opts = {"bucket": str(num_buckets), "bucket-key": key}
        self.catalog.create_table(name, Schema.from_pyarrow_schema(schema, options=opts), False)
        return name

    def _bucketed_table(self, name, schema, key, data):
        """Create a ``NUM_BUCKETS``-bucket table ``name`` and commit ``data`` into it."""
        self._create_bucketed(name, schema, key, self.NUM_BUCKETS)
        table = self.catalog.get_table(name)
        write_builder = table.new_batch_write_builder()
        writer = write_builder.new_write()
        try:
            writer.write_arrow(data)
            write_builder.new_commit().commit(writer.prepare_commit())
        finally:
            # Close the writer even when the write or commit fails.
            writer.close()
        return name

    def test_bucket_join_matches_global_join(self):
        loc_schema = pa.schema([("url", pa.string()), ("row_id", pa.int64())])
        in_schema = pa.schema([("url", pa.string())])
        self._bucketed_table(
            "default.locator", loc_schema, "url",
            pa.Table.from_pydict({"url": [f"u{i}" for i in range(1000)],
                                  "row_id": list(range(1000))}, schema=loc_schema))
        self._bucketed_table(
            "default.input", in_schema, "url",
            pa.Table.from_pydict({"url": [f"u{i}" for i in range(0, 400)]}, schema=in_schema))

        ds = bucket_join(
            "default.input", "default.locator", self.catalog_options,
            on="url", left_projection=["url"], right_projection=["url", "row_id"])
        rows = ds.take_all()
        # Keys are unique on both sides, so the join emits exactly one row per
        # input url; the dict below would hide duplicates on its own.
        self.assertEqual(len(rows), 400)
        got = {r["url"]: r["row_id"] for r in rows}

        # every input url (u0..u399) is matched to its locator row_id
        self.assertEqual(set(got), {f"u{i}" for i in range(400)})
        self.assertEqual(got["u0"], 0)
        self.assertEqual(got["u399"], 399)
        self.assertTrue(all(got[f"u{i}"] == i for i in range(400)))

    def test_fan_out_one_url_many_row_ids(self):
        # A url may map to several locator rows; every match must be emitted.
        loc_schema = pa.schema([("url", pa.string()), ("row_id", pa.int64())])
        in_schema = pa.schema([("url", pa.string())])
        self._bucketed_table(
            "default.loc_fan", loc_schema, "url",
            pa.Table.from_pydict({"url": ["u0", "u0", "u1"], "row_id": [0, 1, 2]},
                                 schema=loc_schema))
        self._bucketed_table(
            "default.in_fan", in_schema, "url",
            pa.Table.from_pydict({"url": ["u0"]}, schema=in_schema))
        ds = bucket_join(
            "default.in_fan", "default.loc_fan", self.catalog_options,
            on="url", left_projection=["url"], right_projection=["url", "row_id"])
        self.assertEqual(sorted(r["row_id"] for r in ds.take_all()), [0, 1])

    def test_empty_result_keeps_schema(self):
        # No shared bucket -> 0 rows, but the join schema must survive.
        loc_schema = pa.schema([("url", pa.string()), ("row_id", pa.int64())])
        in_schema = pa.schema([("url", pa.string())])
        self._bucketed_table(
            "default.loc_empty", loc_schema, "url",
            pa.Table.from_pydict({"url": ["u0", "u1"], "row_id": [0, 1]}, schema=loc_schema))
        self._bucketed_table(
            "default.in_empty", in_schema, "url",
            pa.Table.from_pydict({"url": []}, schema=in_schema))  # no rows -> no buckets
        ds = bucket_join(
            "default.in_empty", "default.loc_empty", self.catalog_options,
            on="url", left_projection=["url"], right_projection=["url", "row_id"])
        self.assertEqual(ds.count(), 0)
        self.assertIn("row_id", ds.schema().names)

    def test_rejects_different_bucket_count(self):
        sch = pa.schema([("url", pa.string())])
        self._create_bucketed("default.cnt_8", sch, "url", 8)
        self._create_bucketed("default.cnt_16", sch, "url", 16)
        with self.assertRaises(ValueError):
            bucket_join("default.cnt_8", "default.cnt_16", self.catalog_options, on="url")

    def test_rejects_different_bucket_key(self):
        sch = pa.schema([("url", pa.string()), ("k", pa.string())])
        self._create_bucketed("default.by_url", sch, "url", 8)
        self._create_bucketed("default.by_k", sch, "k", 8)
        with self.assertRaises(ValueError):
            bucket_join("default.by_url", "default.by_k", self.catalog_options, on="url")

    def test_rejects_join_key_not_bucket_key(self):
        sch = pa.schema([("url", pa.string()), ("k", pa.string())])
        self._create_bucketed("default.k1", sch, "url", 8)
        self._create_bucketed("default.k2", sch, "url", 8)
        with self.assertRaises(ValueError):  # on=k but bucket-key=url
            bucket_join("default.k1", "default.k2", self.catalog_options, on="k")

    def test_rejects_non_inner_join(self):
        sch = pa.schema([("url", pa.string())])
        self._create_bucketed("default.ji1", sch, "url", 8)
        self._create_bucketed("default.ji2", sch, "url", 8)
        with self.assertRaises(ValueError):
            bucket_join("default.ji1", "default.ji2", self.catalog_options,
                        on="url", join_type="left outer")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading