diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py index 4280187956e3..ba88fe5c7aba 100644 --- a/paimon-python/pypaimon/ray/__init__.py +++ b/paimon-python/pypaimon/ray/__init__.py @@ -16,6 +16,7 @@ # under the License. from pypaimon.ray.ray_paimon import read_paimon, write_paimon +from pypaimon.ray.bucket_join import bucket_join from pypaimon.ray.data_evolution_merge_into import ( WhenMatched, WhenNotMatched, @@ -30,6 +31,7 @@ __all__ = [ "read_paimon", "write_paimon", + "bucket_join", "merge_into", "WhenMatched", "WhenNotMatched", diff --git a/paimon-python/pypaimon/ray/bucket_join.py b/paimon-python/pypaimon/ray/bucket_join.py new file mode 100644 index 000000000000..8fc1019c5c3c --- /dev/null +++ b/paimon-python/pypaimon/ray/bucket_join.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Bucket-aligned join on Ray for two co-bucketed Paimon tables. + +Same key -> same bucket on both sides, so each bucket is read and joined in its own +Ray task with no global shuffle -- the no-shuffle alternative to ``ray.data.join``. 
from typing import Any, Dict, List, Optional, Sequence, Union

__all__ = ["bucket_join"]

# A join-key spec: one column name, or a sequence of column names.
OnSpec = Union[str, Sequence[str]]


def _norm(on: OnSpec) -> List[str]:
    """Return ``on`` as a list of column names (a bare string is a single column)."""
    return [on] if isinstance(on, str) else list(on)


def _bucketing(table):
    """Return ``(bucket_count, bucket_key_columns)`` read from ``table.options``.

    The bucket-key option is split on commas and each entry is stripped; blank
    entries (e.g. from a trailing comma) are dropped. A falsy bucket-key yields
    an empty column list.
    """
    count = table.options.bucket()
    key = table.options.bucket_key()
    if not key:
        return count, []
    columns = [part.strip() for part in key.split(",")]
    return count, [column for column in columns if column]


# Per-process table cache so a worker reuses table metadata across the many buckets
# it runs, instead of reloading the catalog per bucket. Only used when reading given
# splits (snapshot-independent); planning always loads a fresh table.
_TABLE_CACHE: Dict = {}


def _get_table(table_id, catalog_options, use_cache):
    """Load ``table_id`` through a catalog built from ``catalog_options``.

    With ``use_cache`` the table is memoized in ``_TABLE_CACHE`` under
    ``(table_id, sorted option items)``; otherwise a new catalog is created and
    the table is loaded on every call.
    """
    from pypaimon.catalog.catalog_factory import CatalogFactory
    if not use_cache:
        return CatalogFactory.create(catalog_options).get_table(table_id)
    key = (table_id, tuple(sorted(catalog_options.items())))
    if key not in _TABLE_CACHE:
        _TABLE_CACHE[key] = CatalogFactory.create(catalog_options).get_table(table_id)
    return _TABLE_CACHE[key]


def _read_builder(table_id, catalog_options, projection, use_cache=False):
    """Return a read builder for the table, projected when ``projection`` is given."""
    rb = _get_table(table_id, catalog_options, use_cache).new_read_builder()
    return rb.with_projection(projection) if projection is not None else rb


def _plan_splits_by_bucket(table_id, catalog_options, projection):
    """Plan the manifest once and group splits by bucket (driver-side, fresh snapshot)."""
    by_bucket = {}
    for s in _read_builder(table_id, catalog_options, projection).new_scan().plan().splits():
        by_bucket.setdefault(s.bucket, []).append(s)
    return by_bucket


def _read_splits(table_id, catalog_options, projection, splits):
    """Read the given splits into an Arrow table."""
    # Reading given splits is snapshot-independent, so the cached table is safe here.
    return _read_builder(table_id, catalog_options, projection, use_cache=True).new_read().to_arrow(splits)


def _check_arguments(on_cols, join_type, left_projection, right_projection):
    """Validate the call arguments that need no catalog access; raise ``ValueError``."""
    if join_type != "inner":
        # Outer joins would need the union of buckets (a bucket missing on one side
        # still emits rows); only inner is correct with the per-bucket intersection.
        raise ValueError(f"bucket_join currently supports only join_type='inner'; got {join_type!r}.")
    if not on_cols:
        raise ValueError("bucket_join requires at least one join key column in 'on'.")
    # The per-bucket join needs the key columns on both sides; catch a projection
    # that drops them here instead of failing inside every remote task.
    for side, projection in (("left_projection", left_projection),
                             ("right_projection", right_projection)):
        if projection is None:
            continue
        missing = [column for column in on_cols if column not in projection]
        if missing:
            raise ValueError(
                f"bucket_join requires {side} to include the join key; missing {missing}.")


def _check_co_bucketed(left, right, lcount, lkey, rcount, rkey, on_cols):
    """Validate that both tables are bucketed identically on the join key; raise ``ValueError``."""
    if not lcount or lcount <= 0 or not rcount or rcount <= 0:
        raise ValueError(
            "bucket_join requires both tables to be fixed-bucket (bucket > 0); "
            f"got {left}={lcount}, {right}={rcount}.")
    if lcount != rcount:
        raise ValueError(
            f"bucket_join requires the same bucket count; {left}={lcount}, {right}={rcount}.")
    if lkey != rkey:
        raise ValueError(
            f"bucket_join requires the same bucket-key; {left}={lkey}, {right}={rkey}.")
    if not lkey:
        # NOTE(review): tables relying on an implicit default bucket-key end up here --
        # confirm whether ``options.bucket_key()`` resolves such defaults.
        raise ValueError(
            "bucket_join requires an explicit bucket-key on both tables; "
            f"{left} and {right} have no bucket-key option set.")
    if on_cols != lkey:
        raise ValueError(
            f"bucket_join requires the join key to be the bucket-key {lkey}; got on={on_cols}. "
            "Equal keys only co-locate by bucket when joining on the bucket-key.")


def bucket_join(
    left: str,
    right: str,
    catalog_options: Dict[str, str],
    *,
    on: OnSpec,
    left_projection: Optional[List[str]] = None,
    right_projection: Optional[List[str]] = None,
    join_type: str = "inner",
    ray_remote_args: Optional[Dict[str, Any]] = None,
) -> "ray.data.Dataset":
    """Join two co-bucketed tables (same bucket count + bucket-key, joined on the
    bucket-key) with no global shuffle.

    The two sides must not share column names other than the join key (pyarrow
    ``join`` would otherwise collide).

    Args:
        left: Identifier of the left table, resolved through the catalog.
        right: Identifier of the right table, resolved through the catalog.
        catalog_options: Options passed to ``CatalogFactory.create``.
        on: Join key column(s); must equal the bucket-key of both tables.
        left_projection: Columns to read from the left table (``None`` = no
            projection). Must include the join key when given.
        right_projection: Columns to read from the right table (``None`` = no
            projection). Must include the join key when given.
        join_type: Only ``"inner"`` is supported.
        ray_remote_args: Optional keyword arguments for ``ray.remote``.

    Returns:
        A ``ray.data.Dataset`` built from one Arrow table per shared bucket.

    Raises:
        ValueError: If ``join_type`` is not ``"inner"``, ``on`` is empty, a
            projection omits the join key, or the tables are not fixed-bucket
            with the same bucket count and the same bucket-key equal to ``on``.
    """
    on_cols = _norm(on)
    # Cheap argument checks first: fail before importing Ray or touching the catalog.
    _check_arguments(on_cols, join_type, left_projection, right_projection)

    import ray
    from pypaimon.catalog.catalog_factory import CatalogFactory

    cat = CatalogFactory.create(catalog_options)
    lcount, lkey = _bucketing(cat.get_table(left))
    rcount, rkey = _bucketing(cat.get_table(right))
    _check_co_bucketed(left, right, lcount, lkey, rcount, rkey, on_cols)

    # Plan each side's manifest once (driver-side, split metadata only -- the join
    # results stay distributed below), then dispatch per-bucket splits to the tasks.
    left_by_bucket = _plan_splits_by_bucket(left, catalog_options, left_projection)
    right_by_bucket = _plan_splits_by_bucket(right, catalog_options, right_projection)

    def _join_bucket(left_splits, right_splits):
        left_t = _read_splits(left, catalog_options, left_projection, left_splits)
        right_t = _read_splits(right, catalog_options, right_projection, right_splits)
        return left_t.join(right_t, keys=on_cols, join_type=join_type)

    # ``@ray.remote()`` (empty parens) is rejected by Ray, so wrap conditionally.
    remote_fn = ray.remote(**ray_remote_args)(_join_bucket) if ray_remote_args else ray.remote(_join_bucket)
    # Inner join: only buckets present on both sides can match.
    buckets = sorted(set(left_by_bucket) & set(right_by_bucket))
    if not buckets:
        # No shared bucket: empty result, but keep the join schema (join two empties).
        empty = _read_splits(left, catalog_options, left_projection, []).join(
            _read_splits(right, catalog_options, right_projection, []),
            keys=on_cols, join_type=join_type)
        return ray.data.from_arrow(empty)
    # Keep each bucket's result as a distributed object ref -- never pulled into the driver.
    refs = [remote_fn.remote(left_by_bucket[b], right_by_bucket[b]) for b in buckets]
    return ray.data.from_arrow_refs(refs)
import os
import shutil
import tempfile
import unittest

import pyarrow as pa
import pytest

pypaimon = pytest.importorskip("pypaimon")
ray = pytest.importorskip("ray")

from pypaimon import CatalogFactory, Schema
from pypaimon.ray import bucket_join


class RayBucketJoinTest(unittest.TestCase):
    """Bucket-aligned join between two HASH_FIXED tables must equal a global join,
    with each bucket joined only against the same bucket (no cross-bucket shuffle)."""

    NUM_BUCKETS = 8

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.catalog_options = {"warehouse": os.path.join(cls.tempdir, "wh")}
        cls.catalog = CatalogFactory.create(cls.catalog_options)
        cls.catalog.create_database("default", True)
        # Remember whether this class started Ray, so teardown never shuts down a
        # cluster that was already running before these tests.
        cls._started_ray = not ray.is_initialized()
        if cls._started_ray:
            ray.init(ignore_reinit_error=True, num_cpus=4)

    @classmethod
    def tearDownClass(cls):
        try:
            if cls._started_ray and ray.is_initialized():
                ray.shutdown()
        except Exception:
            # Best-effort: a failed shutdown must not fail the test class.
            pass
        finally:
            shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_bucketed(self, name, schema, key, num_buckets):
        """Create an empty table with ``num_buckets`` buckets keyed on ``key``."""
        opts = {"bucket": str(num_buckets), "bucket-key": key}
        self.catalog.create_table(name, Schema.from_pyarrow_schema(schema, options=opts), False)
        return name

    def _bucketed_table(self, name, schema, key, data):
        """Create a ``NUM_BUCKETS``-bucket table and commit ``data`` into it."""
        self._create_bucketed(name, schema, key, self.NUM_BUCKETS)
        t = self.catalog.get_table(name)
        wb = t.new_batch_write_builder()
        w = wb.new_write()
        try:
            w.write_arrow(data)
            wb.new_commit().commit(w.prepare_commit())
        finally:
            # Close the writer even when the write or the commit fails.
            w.close()
        return name

    def test_bucket_join_matches_global_join(self):
        loc_schema = pa.schema([("url", pa.string()), ("row_id", pa.int64())])
        in_schema = pa.schema([("url", pa.string())])
        self._bucketed_table(
            "default.locator", loc_schema, "url",
            pa.Table.from_pydict({"url": [f"u{i}" for i in range(1000)],
                                  "row_id": list(range(1000))}, schema=loc_schema))
        self._bucketed_table(
            "default.input", in_schema, "url",
            pa.Table.from_pydict({"url": [f"u{i}" for i in range(0, 400)]}, schema=in_schema))

        ds = bucket_join(
            "default.input", "default.locator", self.catalog_options,
            on="url", left_projection=["url"], right_projection=["url", "row_id"])
        rows = ds.take_all()
        # The dict below collapses repeated urls, so check the row count first:
        # duplicated join output must not go unnoticed.
        self.assertEqual(len(rows), 400)
        got = {r["url"]: r["row_id"] for r in rows}

        # every input url (u0..u399) is matched to its locator row_id
        self.assertEqual(set(got), {f"u{i}" for i in range(400)})
        self.assertEqual(got["u0"], 0)
        self.assertEqual(got["u399"], 399)
        self.assertTrue(all(got[f"u{i}"] == i for i in range(400)))

    def test_fan_out_one_url_many_row_ids(self):
        # A url may map to several locator rows; every match must be emitted.
        loc_schema = pa.schema([("url", pa.string()), ("row_id", pa.int64())])
        in_schema = pa.schema([("url", pa.string())])
        self._bucketed_table(
            "default.loc_fan", loc_schema, "url",
            pa.Table.from_pydict({"url": ["u0", "u0", "u1"], "row_id": [0, 1, 2]},
                                 schema=loc_schema))
        self._bucketed_table(
            "default.in_fan", in_schema, "url",
            pa.Table.from_pydict({"url": ["u0"]}, schema=in_schema))
        ds = bucket_join(
            "default.in_fan", "default.loc_fan", self.catalog_options,
            on="url", left_projection=["url"], right_projection=["url", "row_id"])
        self.assertEqual(sorted(r["row_id"] for r in ds.take_all()), [0, 1])

    def test_empty_result_keeps_schema(self):
        # No shared bucket -> 0 rows, but the join schema must survive.
        loc_schema = pa.schema([("url", pa.string()), ("row_id", pa.int64())])
        in_schema = pa.schema([("url", pa.string())])
        self._bucketed_table(
            "default.loc_empty", loc_schema, "url",
            pa.Table.from_pydict({"url": ["u0", "u1"], "row_id": [0, 1]}, schema=loc_schema))
        self._bucketed_table(
            "default.in_empty", in_schema, "url",
            pa.Table.from_pydict({"url": []}, schema=in_schema))  # no rows -> no buckets
        ds = bucket_join(
            "default.in_empty", "default.loc_empty", self.catalog_options,
            on="url", left_projection=["url"], right_projection=["url", "row_id"])
        self.assertEqual(ds.count(), 0)
        self.assertIn("row_id", ds.schema().names)

    def test_rejects_different_bucket_count(self):
        sch = pa.schema([("url", pa.string())])
        self._create_bucketed("default.cnt_8", sch, "url", 8)
        self._create_bucketed("default.cnt_16", sch, "url", 16)
        with self.assertRaises(ValueError):
            bucket_join("default.cnt_8", "default.cnt_16", self.catalog_options, on="url")

    def test_rejects_different_bucket_key(self):
        sch = pa.schema([("url", pa.string()), ("k", pa.string())])
        self._create_bucketed("default.by_url", sch, "url", 8)
        self._create_bucketed("default.by_k", sch, "k", 8)
        with self.assertRaises(ValueError):
            bucket_join("default.by_url", "default.by_k", self.catalog_options, on="url")

    def test_rejects_join_key_not_bucket_key(self):
        sch = pa.schema([("url", pa.string()), ("k", pa.string())])
        self._create_bucketed("default.k1", sch, "url", 8)
        self._create_bucketed("default.k2", sch, "url", 8)
        with self.assertRaises(ValueError):  # on=k but bucket-key=url
            bucket_join("default.k1", "default.k2", self.catalog_options, on="k")

    def test_rejects_non_inner_join(self):
        sch = pa.schema([("url", pa.string())])
        self._create_bucketed("default.ji1", sch, "url", 8)
        self._create_bucketed("default.ji2", sch, "url", 8)
        with self.assertRaises(ValueError):
            bucket_join("default.ji1", "default.ji2", self.catalog_options,
                        on="url", join_type="left outer")


if __name__ == "__main__":
    unittest.main()