Skip to content

Commit 3868eed

Browse files
committed
tweak spatial join
1 parent 9155110 commit 3868eed

File tree

1 file changed

+112
-43
lines changed

1 file changed

+112
-43
lines changed

src/spatial/operators/spatial_join_physical.cpp

Lines changed: 112 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace duckdb {
2121
// Flat RTree
2222
//======================================================================================================================
2323

24+
// #define DUCKDB_SPATIAL_DEBUG_JOIN
25+
2426
namespace {
2527

2628
template <class T>
@@ -132,11 +134,12 @@ class FlatRTree {
132134
return current_position++;
133135
}
134136

135-
void Sort(vector<uint32_t> &curve) {
136-
Sort(curve, 0, curve.size() - 1);
137+
static void Sort(vector<uint32_t> &curve, typed_view<Box> &box_array, typed_view<uint32_t> &idx_array) {
138+
Sort(curve, box_array, idx_array, 0, curve.size() - 1);
137139
}
138140

139-
void Sort(vector<uint32_t> &curve, size_t l_idx, size_t r_idx) {
141+
static void Sort(vector<uint32_t> &curve, typed_view<Box> &box_array, typed_view<uint32_t> &idx_array, size_t l_idx,
142+
size_t r_idx) {
140143
if (l_idx < r_idx) {
141144
const auto pivot = curve[(l_idx + r_idx) >> 1];
142145
auto pivot_l = l_idx - 1;
@@ -161,8 +164,55 @@ class FlatRTree {
161164
std::swap(idx_array[pivot_l], idx_array[pivot_r]);
162165
}
163166

164-
Sort(curve, l_idx, pivot_r);
165-
Sort(curve, pivot_r + 1, r_idx);
167+
Sort(curve, box_array, idx_array, l_idx, pivot_r);
168+
Sort(curve, box_array, idx_array, pivot_r + 1, r_idx);
169+
}
170+
}
171+
172+
void STRSort(typed_view<Box> &box_array, typed_view<uint32_t> idx_array) {
173+
// Perform Sort-tile-recursive (STR) packing
174+
175+
const auto num_leaf_nodes = (item_count + node_size - 1) / node_size;
176+
const auto num_vertical_slices = static_cast<uint32_t>(std::ceil(std::sqrt(num_leaf_nodes)));
177+
const auto slice_size = (item_count + num_vertical_slices - 1) / num_vertical_slices;
178+
179+
vector<uint32_t> indexes;
180+
for (uint32_t i = 0; i < item_count; i++) {
181+
indexes.push_back(i);
182+
}
183+
184+
// Sort by x-axis into vertical slices
185+
std::sort(indexes.begin(), indexes.end(),
186+
[&](uint32_t a, uint32_t b) { return box_array[a].Center().x < box_array[b].Center().x; });
187+
188+
// Then sort each vertical slice by y-axis
189+
for (uint32_t slice_idx = 0; slice_idx < num_vertical_slices; slice_idx++) {
190+
const auto slice_beg = slice_idx * slice_size;
191+
const auto slice_end = MinValue<uint32_t>(slice_beg + slice_size, item_count);
192+
std::sort(indexes.begin() + slice_beg, indexes.begin() + slice_end,
193+
[&](uint32_t a, uint32_t b) { return box_array[a].Center().y < box_array[b].Center().y; });
194+
}
195+
196+
// Reorder the box_array and idx_array based on the sorted indexes
197+
// DO this in-place. There cannot be any cycles since all indexes are unique
198+
for (int i = 0; i < item_count; i++) {
199+
if (indexes[i] == -1)
200+
continue; // index `i` has been processed, skip
201+
auto box = box_array[i];
202+
auto idx = idx_array[i];
203+
204+
int x = i, y = indexes[i]; // `x` is the current index, `y` is the "target" index
205+
while (y != i) {
206+
indexes[x] = -1; // mark index as processed
207+
box_array[x] = box_array[y];
208+
idx_array[x] = idx_array[y];
209+
x = y;
210+
y = indexes[x];
211+
}
212+
// Now `x` is the index that satisfies `indices[x] == i`.
213+
box_array[x] = box;
214+
idx_array[x] = idx;
215+
indexes[x] = -1;
166216
}
167217
}
168218

@@ -176,6 +226,7 @@ class FlatRTree {
176226

177227
// Generate hilbert curve values
178228
// TODO: Parallelize this with tasks when the number of items is large?
229+
179230
constexpr auto max_hilbert = std::numeric_limits<uint16_t>::max();
180231
const auto hw = max_hilbert / (tree_box.max.x - tree_box.min.x);
181232
const auto hh = max_hilbert / (tree_box.max.y - tree_box.min.y);
@@ -191,7 +242,8 @@ class FlatRTree {
191242
}
192243

193244
// Now, sort the indices based on their curve value
194-
Sort(curve);
245+
Sort(curve, box_array, idx_array);
246+
//STRSort(box_array, idx_array);
195247

196248
size_t layer_idx = 0;
197249
size_t entry_idx = 0;
@@ -723,11 +775,17 @@ class SpatialJoinGlobalOperatorState final : public GlobalOperatorState {
723775
// TODO: Move this into proper profiling metrics later
724776
// Statistics
725777
atomic<idx_t> total_rtree_probes = {0};
778+
atomic<idx_t> total_rtree_successfull_probes = {0};
726779
atomic<idx_t> total_rtree_candidates = {0};
780+
atomic<idx_t> max_candidates = {0};
781+
atomic<idx_t> min_candidates = {std::numeric_limits<idx_t>::max()};
727782

728783
~SpatialJoinGlobalOperatorState() override {
729-
Printer::PrintF("Spatial Join Stats: RTree Probes: %llu, RTree Candidates: %llu\n", total_rtree_probes.load(),
730-
total_rtree_candidates.load());
784+
Printer::PrintF("Spatial Join RTree Probes: %llu\n", total_rtree_probes.load());
785+
Printer::PrintF("Spatial Join RTree Successful Probes: %llu\n", total_rtree_successfull_probes.load());
786+
Printer::PrintF("Spatial Join RTree Candidates: %llu\n", total_rtree_candidates.load());
787+
Printer::PrintF("Spatial Join RTree Max Candidates per Probe: %llu\n", max_candidates.load());
788+
Printer::PrintF("Spatial Join RTree Min Candidates per Probe: %llu\n", min_candidates.load());
731789
}
732790
#endif
733791
};
@@ -843,14 +901,20 @@ OperatorResultType PhysicalSpatialJoin::ExecuteInternal(ExecutionContext &contex
843901

844902
gstate.rtree->InitScan(lstate.scan, bbox);
845903

904+
#ifdef DUCKDB_SPATIAL_DEBUG_JOIN
905+
gstate.total_rtree_probes += 1;
906+
#endif
907+
846908
if (!gstate.rtree->Scan(lstate.scan)) {
847909
lstate.input_index++;
848910
continue;
849911
}
850912

851913
#ifdef DUCKDB_SPATIAL_DEBUG_JOIN
914+
gstate.total_rtree_successfull_probes += 1;
852915
gstate.total_rtree_candidates += lstate.scan.matches_count;
853-
gstate.total_rtree_probes += 1;
916+
gstate.max_candidates = MaxValue(gstate.max_candidates.load(), lstate.scan.matches_count);
917+
gstate.min_candidates = MinValue(gstate.min_candidates.load(), lstate.scan.matches_count);
854918
#endif
855919

856920
lstate.state = SpatialJoinState::SCAN;
@@ -1042,25 +1106,34 @@ class SpatialJoinGlobalSourceState final : public GlobalSourceState {
10421106
column_ids.push_back(op.build_side_key_types.size() + op.build_side_payload_types.size());
10431107

10441108
// We dont need to keep the tuples aroun after scanning
1045-
state.collection->InitializeScan(scan_state, std::move(column_ids),
1046-
TupleDataPinProperties::KEEP_EVERYTHING_PINNED);
1109+
state.collection->InitializeScan(scan_state, std::move(column_ids), TupleDataPinProperties::UNPIN_AFTER_DONE);
10471110

10481111
tuples_maximum = state.collection->Count();
10491112
}
10501113

10511114
const PhysicalSpatialJoin &op;
1115+
1116+
mutex scan_lock;
10521117
TupleDataParallelScanState scan_state;
1118+
10531119
// How many tuples we have scanned so far
10541120
idx_t tuples_maximum = 0;
10551121
atomic<idx_t> tuples_scanned = {0};
10561122

10571123
public:
10581124
idx_t MaxThreads() override {
10591125
const auto &state = op.op_state->Cast<SpatialJoinGlobalOperatorState>();
1060-
const auto count = state.collection->Count();
1126+
return state.collection->ChunkCount();
1127+
}
1128+
1129+
bool Scan(TupleDataLocalScanState &local_scan, DataChunk &chunk) {
1130+
const auto &collection = op.op_state->Cast<SpatialJoinGlobalOperatorState>().collection;
10611131

1062-
// Rough approximation of the number of threads to use
1063-
return count / (STANDARD_VECTOR_SIZE * 10ULL);
1132+
lock_guard<mutex> guard(scan_lock);
1133+
const auto not_empty = collection->Scan(scan_state, local_scan, chunk);
1134+
tuples_scanned += chunk.size();
1135+
1136+
return not_empty;
10641137
}
10651138
};
10661139

@@ -1106,45 +1179,41 @@ SourceResultType PhysicalSpatialJoin::GetDataInternal(ExecutionContext &context,
11061179
auto &gstate = input.global_state.Cast<SpatialJoinGlobalSourceState>();
11071180
auto &lstate = input.local_state.Cast<SpatialJoinLocalSourceState>();
11081181

1109-
const auto &tuples = gstate.op.op_state->Cast<SpatialJoinGlobalOperatorState>().collection;
1110-
1111-
while (tuples->Scan(gstate.scan_state, lstate.scan_state, lstate.scan_chunk)) {
1112-
gstate.tuples_scanned += lstate.scan_chunk.size();
1182+
if (!gstate.Scan(lstate.scan_state, lstate.scan_chunk)) {
1183+
return SourceResultType::FINISHED;
1184+
}
11131185

1114-
const auto matches = FlatVector::GetData<bool>(lstate.scan_chunk.data.back());
1186+
const auto matches = FlatVector::GetData<bool>(lstate.scan_chunk.data.back());
11151187

1116-
idx_t result_count = 0;
1117-
for (idx_t i = 0; i < lstate.scan_chunk.size(); i++) {
1118-
if (!matches[i]) {
1119-
lstate.match_sel.set_index(result_count++, i);
1120-
}
1188+
idx_t result_count = 0;
1189+
for (idx_t i = 0; i < lstate.scan_chunk.size(); i++) {
1190+
if (!matches[i]) {
1191+
lstate.match_sel.set_index(result_count++, i);
11211192
}
1193+
}
11221194

1123-
if (result_count > 0) {
1195+
if (result_count > 0) {
11241196

1125-
const auto lhs_col_count = probe_side_output_columns.size();
1126-
const auto rhs_col_count = build_side_output_columns.size();
1197+
const auto lhs_col_count = probe_side_output_columns.size();
1198+
const auto rhs_col_count = build_side_output_columns.size();
11271199

1128-
// Null the LHS columns
1129-
for (idx_t i = 0; i < lhs_col_count; i++) {
1130-
auto &target = chunk.data[i];
1131-
target.SetVectorType(VectorType::CONSTANT_VECTOR);
1132-
ConstantVector::SetNull(target, true);
1133-
}
1134-
1135-
// Set the RHS columns
1136-
for (idx_t i = 0; i < rhs_col_count; i++) {
1137-
auto &target = chunk.data[lhs_col_count + i];
1138-
// Offset by one here to skip the match column
1139-
target.Slice(lstate.scan_chunk.data[i], lstate.match_sel, result_count);
1140-
}
1200+
// Null the LHS columns
1201+
for (idx_t i = 0; i < lhs_col_count; i++) {
1202+
auto &target = chunk.data[i];
1203+
target.SetVectorType(VectorType::CONSTANT_VECTOR);
1204+
ConstantVector::SetNull(target, true);
1205+
}
11411206

1142-
chunk.SetCardinality(result_count);
1143-
return SourceResultType::HAVE_MORE_OUTPUT;
1207+
// Set the RHS columns
1208+
for (idx_t i = 0; i < rhs_col_count; i++) {
1209+
auto &target = chunk.data[lhs_col_count + i];
1210+
// Offset by one here to skip the match column
1211+
target.Slice(lstate.scan_chunk.data[i], lstate.match_sel, result_count);
11441212
}
11451213
}
11461214

1147-
return SourceResultType::FINISHED;
1215+
chunk.SetCardinality(result_count);
1216+
return SourceResultType::HAVE_MORE_OUTPUT;
11481217
}
11491218

11501219
//----------------------------------------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)