Skip to content

Commit

Permalink
sync updates for group by
Browse files Browse the repository at this point in the history
  • Loading branch information
jingshi-ant committed Jul 27, 2023
1 parent e3ae8f1 commit 819b969
Show file tree
Hide file tree
Showing 21 changed files with 779 additions and 63 deletions.
14 changes: 8 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

### Fixed

## [0.1.0] - 2023-03-28
- Speed up GROUP BY with HEU in some scenarios.

### Added

- SCQL init release
### Fixed

## [0.2.0] - 2023-07-05

Expand Down Expand Up @@ -55,3 +51,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed database creation failure [#19](https://github.com/secretflow/scql/issues/19).
- Fixed GROUP BY on string columns not being supported [#48](https://github.com/secretflow/scql/pull/48).

## [0.1.0] - 2023-03-28

### Added

- SCQL init release
7 changes: 7 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ load("@spulib//bazel:repositories.bzl", "spu_deps")

spu_deps()

#
# heu
#
load("@com_alipay_sf_heu//third_party/bazel_cpp:repositories.bzl", "heu_cpp_deps")

heu_cpp_deps()

#
# yacl
#
Expand Down
13 changes: 9 additions & 4 deletions cmd/regtest/testdata/two_parties.json
Original file line number Diff line number Diff line change
Expand Up @@ -316,25 +316,30 @@
"mysql_query": "select alice.compare_long_0 not in (select compare_long_0 from bob.tbl_0) as not_in_column from alice.tbl_0 as alice where alice.compare_long_0 > 1000;"
},
{
"name": "agg with group by(long)",
"name": "agg with group by(long), private group by",
"query": "select ta.groupby_long_0, sum(ta.aggregate_long_0) as b, max(ta.aggregate_long_0) as a, min(ta.aggregate_long_0) as d from scdb.alice_tbl_0 as ta join scdb.bob_tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_long_0",
"mysql_query": "select ta.groupby_long_0, sum(ta.aggregate_long_0) as b, max(ta.aggregate_long_0) as a, min(ta.aggregate_long_0) as d from alice.tbl_0 as ta join bob.tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_long_0 having count(*) >= 4"
},
{
"name": "agg with group by(float)",
"name": "agg with group by(float), private group by",
"query": "select ta.groupby_long_0, sum(ta.aggregate_float_0) as b, max(ta.aggregate_float_0 + ta.aggregate_float_0) as a, min(ta.aggregate_float_0) as d from scdb.alice_tbl_0 as ta join scdb.bob_tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_long_0",
"mysql_query": "select ta.groupby_long_0, sum(ta.aggregate_float_0) as b, max(ta.aggregate_float_0 + ta.aggregate_float_0) as a, min(ta.aggregate_float_0) as d from alice.tbl_0 as ta join bob.tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_long_0 having count(*) >= 4"
},
{
"name": "agg with group by(string)",
"name": "agg with group by(string), private group by",
"query": "select ta.groupby_string_0, count(*) as c from scdb.alice_tbl_0 as ta join scdb.bob_tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_string_0",
"mysql_query": "select ta.groupby_string_0, count(*) as c from alice.tbl_0 as ta join bob.tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_string_0 having count(*) >= 4"
},
{
"name": "agg with group by(string2)",
"name": "agg with group by(string2), oblivious group by",
"query": "select ta.groupby_string_0,tb.groupby_string_0, count(*) as c from scdb.alice_tbl_0 as ta join scdb.bob_tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_string_0,tb.groupby_string_0",
"mysql_query": "select ta.groupby_string_0,tb.groupby_string_0, count(*) as c from alice.tbl_0 as ta join bob.tbl_0 as tb on ta.plain_long_0 = tb.plain_long_0 group by ta.groupby_string_0,tb.groupby_string_0 having count(*) >= 4"
},
{
"name": "he group by",
"query": "select ta.groupby_string_0, sum(tb.aggregate_long_0) as sl, sum(tb.aggregate_float_0) as sf, sum(ta.compare_long_0 > tb.compare_long_0) as sc, count(tb.encrypt_long_0) as c from scdb.alice_tbl_0 as ta join scdb.bob_tbl_0 as tb on ta.join_long_0 = tb.join_long_0 group by ta.groupby_string_0",
"mysql_query": "select ta.groupby_string_0, sum(tb.aggregate_long_0) as sl, sum(tb.aggregate_float_0) as sf, sum(ta.compare_long_0 > tb.compare_long_0) as sc, count(tb.encrypt_long_0) as c from alice.tbl_0 as ta join bob.tbl_0 as tb on ta.join_long_0 = tb.join_long_0 group by ta.groupby_string_0 having count(*) >= 4;"
},
{
"name": "select plain data from another party",
"query": "select plain_long_0, plain_string_0, plain_float_0 from scdb.bob_tbl_0",
Expand Down
8 changes: 8 additions & 0 deletions engine/bazel/engine_deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")


SECRETFLOW_GROUP_GIT = "https://github.com/secretflow"

SPU_GIT = "https://github.com/secretflow/spu.git"


def engine_deps():
_com_github_nelhage_rules_boost()
_org_apache_arrow()
Expand Down Expand Up @@ -48,6 +50,12 @@ def engine_deps():
tag = "0.4.1b1",
remote = SPU_GIT,
)
maybe(
git_repository,
name = "com_alipay_sf_heu",
tag = "v0.4.4b0",
remote = "https://github.com/secretflow/heu.git",
)

def _org_apache_arrow():
maybe(
Expand Down
25 changes: 25 additions & 0 deletions engine/operator/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_library(
":filter_by_index",
":group",
":group_agg",
":group_he_sum",
":in",
":join",
":logical",
Expand Down Expand Up @@ -697,3 +698,27 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "group_he_sum",
srcs = ["group_he_sum.cc"],
hdrs = ["group_he_sum.h"],
deps = [
"//engine/framework:operator",
"//engine/util:spu_io",
"//engine/util:tensor_util",
"@com_alipay_sf_heu//heu/library/numpy",
"@com_github_gflags_gflags//:gflags",
],
)

cc_test(
name = "group_he_sum_test",
srcs = ["group_he_sum_test.cc"],
deps = [
":group_he_sum",
":test_util",
"//engine/core:tensor_from_json",
"@com_google_googletest//:gtest_main",
],
)
2 changes: 2 additions & 0 deletions engine/operator/all_ops_register.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "engine/operator/filter_by_index.h"
#include "engine/operator/group.h"
#include "engine/operator/group_agg.h"
#include "engine/operator/group_he_sum.h"
#include "engine/operator/in.h"
#include "engine/operator/join.h"
#include "engine/operator/logical.h"
Expand Down Expand Up @@ -113,6 +114,7 @@ void RegisterAllOpsImpl() {
ADD_OPERATOR_TO_REGISTRY(GroupAvg);
ADD_OPERATOR_TO_REGISTRY(GroupMin);
ADD_OPERATOR_TO_REGISTRY(GroupMax);
ADD_OPERATOR_TO_REGISTRY(GroupHeSum);

// oblivious groupby
ADD_OPERATOR_TO_REGISTRY(ObliviousGroupMark);
Expand Down
Loading

0 comments on commit 819b969

Please sign in to comment.