@@ -21,6 +21,8 @@ namespace duckdb {
2121// Flat RTree
2222// ======================================================================================================================
2323
24+ // #define DUCKDB_SPATIAL_DEBUG_JOIN
25+
2426namespace {
2527
2628template <class T >
@@ -132,11 +134,12 @@ class FlatRTree {
132134 return current_position++;
133135 }
134136
135- void Sort (vector<uint32_t > &curve) {
136- Sort (curve, 0 , curve.size () - 1 );
137+ static void Sort (vector<uint32_t > &curve, typed_view<Box> &box_array, typed_view< uint32_t > &idx_array ) {
138+ Sort (curve, box_array, idx_array, 0 , curve.size () - 1 );
137139 }
138140
139- void Sort (vector<uint32_t > &curve, size_t l_idx, size_t r_idx) {
141+ static void Sort (vector<uint32_t > &curve, typed_view<Box> &box_array, typed_view<uint32_t > &idx_array, size_t l_idx,
142+ size_t r_idx) {
140143 if (l_idx < r_idx) {
141144 const auto pivot = curve[(l_idx + r_idx) >> 1 ];
142145 auto pivot_l = l_idx - 1 ;
@@ -161,8 +164,55 @@ class FlatRTree {
161164 std::swap (idx_array[pivot_l], idx_array[pivot_r]);
162165 }
163166
164- Sort (curve, l_idx, pivot_r);
165- Sort (curve, pivot_r + 1 , r_idx);
167+ Sort (curve, box_array, idx_array, l_idx, pivot_r);
168+ Sort (curve, box_array, idx_array, pivot_r + 1 , r_idx);
169+ }
170+ }
171+
172+ void STRSort (typed_view<Box> &box_array, typed_view<uint32_t > idx_array) {
173+ // Perform Sort-tile-recursive (STR) packing
174+
175+ const auto num_leaf_nodes = (item_count + node_size - 1 ) / node_size;
176+ const auto num_vertical_slices = static_cast <uint32_t >(std::ceil (std::sqrt (num_leaf_nodes)));
177+ const auto slice_size = (item_count + num_vertical_slices - 1 ) / num_vertical_slices;
178+
179+ vector<uint32_t > indexes;
180+ for (uint32_t i = 0 ; i < item_count; i++) {
181+ indexes.push_back (i);
182+ }
183+
184+ // Sort by x-axis into vertical slices
185+ std::sort (indexes.begin (), indexes.end (),
186+ [&](uint32_t a, uint32_t b) { return box_array[a].Center ().x < box_array[b].Center ().x ; });
187+
188+ // Then sort each vertical slice by y-axis
189+ for (uint32_t slice_idx = 0 ; slice_idx < num_vertical_slices; slice_idx++) {
190+ const auto slice_beg = slice_idx * slice_size;
191+ const auto slice_end = MinValue<uint32_t >(slice_beg + slice_size, item_count);
192+ std::sort (indexes.begin () + slice_beg, indexes.begin () + slice_end,
193+ [&](uint32_t a, uint32_t b) { return box_array[a].Center ().y < box_array[b].Center ().y ; });
194+ }
195+
196+ // Reorder the box_array and idx_array based on the sorted indexes
197+ // DO this in-place. There cannot be any cycles since all indexes are unique
198+ for (int i = 0 ; i < item_count; i++) {
199+ if (indexes[i] == -1 )
200+ continue ; // index `i` has been processed, skip
201+ auto box = box_array[i];
202+ auto idx = idx_array[i];
203+
204+ int x = i, y = indexes[i]; // `x` is the current index, `y` is the "target" index
205+ while (y != i) {
206+ indexes[x] = -1 ; // mark index as processed
207+ box_array[x] = box_array[y];
208+ idx_array[x] = idx_array[y];
209+ x = y;
210+ y = indexes[x];
211+ }
212+ // Now `x` is the index that satisfies `indices[x] == i`.
213+ box_array[x] = box;
214+ idx_array[x] = idx;
215+ indexes[x] = -1 ;
166216 }
167217 }
168218
@@ -176,6 +226,7 @@ class FlatRTree {
176226
177227 // Generate hilbert curve values
178228 // TODO: Parallelize this with tasks when the number of items is large?
229+
179230 constexpr auto max_hilbert = std::numeric_limits<uint16_t >::max ();
180231 const auto hw = max_hilbert / (tree_box.max .x - tree_box.min .x );
181232 const auto hh = max_hilbert / (tree_box.max .y - tree_box.min .y );
@@ -191,7 +242,8 @@ class FlatRTree {
191242 }
192243
193244 // Now, sort the indices based on their curve value
194- Sort (curve);
245+ Sort (curve, box_array, idx_array);
246+ // STRSort(box_array, idx_array);
195247
196248 size_t layer_idx = 0 ;
197249 size_t entry_idx = 0 ;
@@ -723,11 +775,17 @@ class SpatialJoinGlobalOperatorState final : public GlobalOperatorState {
723775 // TODO: Move this into proper profiling metrics later
724776 // Statistics
725777 atomic<idx_t > total_rtree_probes = {0 };
778+ atomic<idx_t > total_rtree_successfull_probes = {0 };
726779 atomic<idx_t > total_rtree_candidates = {0 };
780+ atomic<idx_t > max_candidates = {0 };
781+ atomic<idx_t > min_candidates = {std::numeric_limits<idx_t >::max ()};
727782
728783 ~SpatialJoinGlobalOperatorState () override {
729- Printer::PrintF (" Spatial Join Stats: RTree Probes: %llu, RTree Candidates: %llu\n " , total_rtree_probes.load (),
730- total_rtree_candidates.load ());
784+ Printer::PrintF (" Spatial Join RTree Probes: %llu\n " , total_rtree_probes.load ());
785+ Printer::PrintF (" Spatial Join RTree Successful Probes: %llu\n " , total_rtree_successfull_probes.load ());
786+ Printer::PrintF (" Spatial Join RTree Candidates: %llu\n " , total_rtree_candidates.load ());
787+ Printer::PrintF (" Spatial Join RTree Max Candidates per Probe: %llu\n " , max_candidates.load ());
788+ Printer::PrintF (" Spatial Join RTree Min Candidates per Probe: %llu\n " , min_candidates.load ());
731789 }
732790#endif
733791};
@@ -843,14 +901,20 @@ OperatorResultType PhysicalSpatialJoin::ExecuteInternal(ExecutionContext &contex
843901
844902 gstate.rtree ->InitScan (lstate.scan , bbox);
845903
904+ #ifdef DUCKDB_SPATIAL_DEBUG_JOIN
905+ gstate.total_rtree_probes += 1 ;
906+ #endif
907+
846908 if (!gstate.rtree ->Scan (lstate.scan )) {
847909 lstate.input_index ++;
848910 continue ;
849911 }
850912
851913#ifdef DUCKDB_SPATIAL_DEBUG_JOIN
914+ gstate.total_rtree_successfull_probes += 1 ;
852915 gstate.total_rtree_candidates += lstate.scan .matches_count ;
853- gstate.total_rtree_probes += 1 ;
916+ gstate.max_candidates = MaxValue (gstate.max_candidates .load (), lstate.scan .matches_count );
917+ gstate.min_candidates = MinValue (gstate.min_candidates .load (), lstate.scan .matches_count );
854918#endif
855919
856920 lstate.state = SpatialJoinState::SCAN;
@@ -1042,25 +1106,34 @@ class SpatialJoinGlobalSourceState final : public GlobalSourceState {
10421106 column_ids.push_back (op.build_side_key_types .size () + op.build_side_payload_types .size ());
10431107
10441108 // We don't need to keep the tuples around after scanning
1045- state.collection ->InitializeScan (scan_state, std::move (column_ids),
1046- TupleDataPinProperties::KEEP_EVERYTHING_PINNED);
1109+ state.collection ->InitializeScan (scan_state, std::move (column_ids), TupleDataPinProperties::UNPIN_AFTER_DONE);
10471110
10481111 tuples_maximum = state.collection ->Count ();
10491112 }
10501113
10511114 const PhysicalSpatialJoin &op;
1115+
1116+ mutex scan_lock;
10521117 TupleDataParallelScanState scan_state;
1118+
10531119 // How many tuples we have scanned so far
10541120 idx_t tuples_maximum = 0 ;
10551121 atomic<idx_t > tuples_scanned = {0 };
10561122
10571123public:
10581124 idx_t MaxThreads () override {
10591125 const auto &state = op.op_state ->Cast <SpatialJoinGlobalOperatorState>();
1060- const auto count = state.collection ->Count ();
1126+ return state.collection ->ChunkCount ();
1127+ }
1128+
1129+ bool Scan (TupleDataLocalScanState &local_scan, DataChunk &chunk) {
1130+ const auto &collection = op.op_state ->Cast <SpatialJoinGlobalOperatorState>().collection ;
10611131
1062- // Rough approximation of the number of threads to use
1063- return count / (STANDARD_VECTOR_SIZE * 10ULL );
1132+ lock_guard<mutex> guard (scan_lock);
1133+ const auto not_empty = collection->Scan (scan_state, local_scan, chunk);
1134+ tuples_scanned += chunk.size ();
1135+
1136+ return not_empty;
10641137 }
10651138};
10661139
@@ -1106,45 +1179,41 @@ SourceResultType PhysicalSpatialJoin::GetDataInternal(ExecutionContext &context,
11061179 auto &gstate = input.global_state .Cast <SpatialJoinGlobalSourceState>();
11071180 auto &lstate = input.local_state .Cast <SpatialJoinLocalSourceState>();
11081181
1109- const auto &tuples = gstate.op .op_state ->Cast <SpatialJoinGlobalOperatorState>().collection ;
1110-
1111- while (tuples->Scan (gstate.scan_state , lstate.scan_state , lstate.scan_chunk )) {
1112- gstate.tuples_scanned += lstate.scan_chunk .size ();
1182+ if (!gstate.Scan (lstate.scan_state , lstate.scan_chunk )) {
1183+ return SourceResultType::FINISHED;
1184+ }
11131185
1114- const auto matches = FlatVector::GetData<bool >(lstate.scan_chunk .data .back ());
1186+ const auto matches = FlatVector::GetData<bool >(lstate.scan_chunk .data .back ());
11151187
1116- idx_t result_count = 0 ;
1117- for (idx_t i = 0 ; i < lstate.scan_chunk .size (); i++) {
1118- if (!matches[i]) {
1119- lstate.match_sel .set_index (result_count++, i);
1120- }
1188+ idx_t result_count = 0 ;
1189+ for (idx_t i = 0 ; i < lstate.scan_chunk .size (); i++) {
1190+ if (!matches[i]) {
1191+ lstate.match_sel .set_index (result_count++, i);
11211192 }
1193+ }
11221194
1123- if (result_count > 0 ) {
1195+ if (result_count > 0 ) {
11241196
1125- const auto lhs_col_count = probe_side_output_columns.size ();
1126- const auto rhs_col_count = build_side_output_columns.size ();
1197+ const auto lhs_col_count = probe_side_output_columns.size ();
1198+ const auto rhs_col_count = build_side_output_columns.size ();
11271199
1128- // Null the LHS columns
1129- for (idx_t i = 0 ; i < lhs_col_count; i++) {
1130- auto &target = chunk.data [i];
1131- target.SetVectorType (VectorType::CONSTANT_VECTOR);
1132- ConstantVector::SetNull (target, true );
1133- }
1134-
1135- // Set the RHS columns
1136- for (idx_t i = 0 ; i < rhs_col_count; i++) {
1137- auto &target = chunk.data [lhs_col_count + i];
1138- // Offset by one here to skip the match column
1139- target.Slice (lstate.scan_chunk .data [i], lstate.match_sel , result_count);
1140- }
1200+ // Null the LHS columns
1201+ for (idx_t i = 0 ; i < lhs_col_count; i++) {
1202+ auto &target = chunk.data [i];
1203+ target.SetVectorType (VectorType::CONSTANT_VECTOR);
1204+ ConstantVector::SetNull (target, true );
1205+ }
11411206
1142- chunk.SetCardinality (result_count);
1143- return SourceResultType::HAVE_MORE_OUTPUT;
1207+ // Set the RHS columns
1208+ for (idx_t i = 0 ; i < rhs_col_count; i++) {
1209+ auto &target = chunk.data [lhs_col_count + i];
1210+ // Offset by one here to skip the match column
1211+ target.Slice (lstate.scan_chunk .data [i], lstate.match_sel , result_count);
11441212 }
11451213 }
11461214
1147- return SourceResultType::FINISHED;
1215+ chunk.SetCardinality (result_count);
1216+ return SourceResultType::HAVE_MORE_OUTPUT;
11481217}
11491218
11501219// ----------------------------------------------------------------------------------------------------------------------
0 commit comments