
Commit

repo-sync-2024-06-06T17:12:08+0800 (#69)
huocun-ant authored Jun 6, 2024
1 parent ef7c18b commit 99f708d
Showing 24 changed files with 1,092 additions and 249 deletions.
4 changes: 2 additions & 2 deletions .ci/accuracy_test.py
@@ -135,8 +135,8 @@ def get_s3_configs(self) -> Dict[str, dict]:

@dataclass
class PartyConfig(ServerBaseConifg):
model_package_path: str
csv_path: str
model_package_path: str = None
csv_path: str = None
query_datas: List[str] = None
query_context: str = None

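The `= None` defaults here are not cosmetic: the base `ServerBaseConifg` gains a defaulted field in this commit (`http_feature_source_port: int = None` in `.ci/test_common.py` below), and Python dataclasses reject a non-default field that follows a defaulted one, including across inheritance. A minimal sketch of the rule, with hypothetical field names:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Base:
    host: str
    port: Optional[int] = None  # defaulted field in the base class

@dataclass
class Derived(Base):
    # path: str               # TypeError: non-default argument 'path'
    #                         #            follows default argument
    path: Optional[str] = None  # so subclass fields must also carry defaults
```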
52 changes: 35 additions & 17 deletions .ci/integration_test.py
@@ -239,6 +239,10 @@ def test(self, config):


class CsvTest(TestCase):
def __init__(self, path: str, use_http_feature_source: bool = False):
super().__init__(path)
self.use_http_feature_source = use_http_feature_source

def get_config(self, path: str):
dot_node_alice = make_dot_product_node_def(
name="node_dot_product",
@@ -300,6 +304,7 @@ def get_config(self, path: str):
"v1": [1, 2, 3, 4],
"v2": [5, 6, 7, 8],
},
use_http_feature_source=self.use_http_feature_source,
)
bob_config = PartyConfig(
id="bob",
@@ -310,10 +315,11 @@ def get_config(self, path: str):
graph_def=bob_graph,
query_datas=["a", "b", "c"], # Corresponds to the id column in csv
csv_dict={
"id": ["a", "b", "c"],
"vv3": [1, 2, 3],
"vv2": [5, 6, 7],
"id": ["b", "a", "c"],
"vv3": [2, 1, 3],
"vv2": [6, 5, 7],
},
use_http_feature_source=self.use_http_feature_source,
)
return TestConfig(
path,
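The shuffled row order in `bob`'s `csv_dict` (`"id": ["b", "a", "c"]`, with the values permuted to match) looks like a check that the CSV feature source joins rows by the `id` column rather than by file position, so the expected scores stay the same. A hedged sketch of that id-keyed alignment (illustrative only, not the serving implementation):

```python
# Illustrative only: align feature rows to the query order by id,
# so the on-disk row order must not matter.
csv_dict = {
    "id": ["b", "a", "c"],
    "vv3": [2, 1, 3],
    "vv2": [6, 5, 7],
}
query_ids = ["a", "b", "c"]

rows_by_id = {
    row_id: {k: csv_dict[k][i] for k in csv_dict if k != "id"}
    for i, row_id in enumerate(csv_dict["id"])
}
aligned = [rows_by_id[q] for q in query_ids]
assert aligned[0] == {"vv3": 1, "vv2": 5}  # features for id "a"
```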
@@ -633,13 +639,11 @@ def get_spans_from_trace_file(trace_filename):
), f"execution parent span id mismatch, expected: {service_span_id_dict['bob']}, actual: {stub_span_id_dict['alice']}"


if __name__ == "__main__":
ExampleTest('examples').exec()

def get_mock_feature_glm_test():
# glm
with open(".ci/simple_test/node_processing_alice.json", "rb") as f:
alice_trace_content = f.read()
MockFeatureTest(
return MockFeatureTest(
service_id="glm",
path='model_path',
nodes={
@@ -811,10 +815,11 @@ def get_spans_from_trace_file(trace_filename):
"v10": "x10",
},
},
).exec()
)

# sgb
MockFeatureTest(

def get_mock_feature_sgb_test():
return MockFeatureTest(
service_id="sgb",
path='sgb_model',
nodes={
@@ -1484,10 +1489,11 @@ def get_spans_from_trace_file(trace_filename):
],
},
specific_party="alice",
).exec()
)


# sgb with alice no feature
MockFeatureTest(
def get_sgb_alice_no_feature_test():
return MockFeatureTest(
service_id="sgb_bob_no_feature",
path='sgb_model',
nodes={
@@ -1893,10 +1899,11 @@ def get_spans_from_trace_file(trace_filename):
],
},
specific_party="alice",
).exec()
)

# glm with bob no feature
MockFeatureTest(

def get_glm_bob_no_feature_test():
return MockFeatureTest(
service_id="glm_with_bob_no_feature",
path='model_path',
nodes={
@@ -1988,9 +1995,20 @@ def get_spans_from_trace_file(trace_filename):
},
"bob": {},
},
).exec()
)


if __name__ == "__main__":
ExampleTest('examples').exec()

get_mock_feature_glm_test().exec()
get_mock_feature_sgb_test().exec()

get_sgb_alice_no_feature_test().exec()
get_glm_bob_no_feature_test().exec()

PredefinedErrorTest('model_path').exec()
PredefineTest('model_path').exec()
CsvTest('model_path').exec()
CsvTest('model_path', use_http_feature_source=True).exec()
SpecificTest('model_path').exec()
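The `__main__` block now composes the suite from small factory functions (`get_mock_feature_glm_test()` and friends) instead of one long script, and `CsvTest` runs twice: once reading the CSV directly and once with `use_http_feature_source=True`, which routes the same data through the HTTP feature service. A simplified sketch of the pattern (the generator is illustrative, not part of the file):

```python
# Illustrative: factories build the cases, __main__ decides what runs.
def make_csv_cases():
    yield CsvTest('model_path')                                # embedded CSV source
    yield CsvTest('model_path', use_http_feature_source=True)  # CSV served over HTTP

if __name__ == "__main__":
    for case in make_csv_cases():
        case.exec()
```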
106 changes: 78 additions & 28 deletions .ci/test_common.py
@@ -66,6 +66,7 @@ def global_ip_config(index):
communication_port = [8710, 8711]
metrics_port = [8318, 8319]
brpc_builtin_port = [8328, 8329]
http_feature_source_port = [8338, 8339]
assert index < len(cluster_ip)
return {
"cluster_ip": cluster_ip[index],
@@ -74,6 +75,7 @@ def global_ip_config(index):
"communication_port": communication_port[index],
"metrics_port": metrics_port[index],
"brpc_builtin_service_port": brpc_builtin_port[index],
"http_feature_source_port": http_feature_source_port[index],
}
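`global_ip_config(index)` hands each party a fixed set of local ports; the new `http_feature_source_port` (8338/8339) is where the per-party simple feature service will listen. A small usage sketch, assuming index 0 is alice and index 1 is bob as elsewhere in these tests:

```python
# Illustrative: index 0 -> alice, index 1 -> bob in this harness.
alice_net = global_ip_config(0)
bob_net = global_ip_config(1)
assert alice_net["http_feature_source_port"] == 8338
assert bob_net["http_feature_source_port"] == 8339
```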


@@ -112,6 +114,26 @@ def build_get_model_info_cmd(host: str, port: int, service_spec_id: str):
return f'curl --location "{url}" --header "Content-Type: application/json" --data \'{json.dumps(body_dict)}\''


class ProcRunGuard:
def __init__(self):
self.sub_proc_list = []

def __del__(self):
self.cleanup_sub_procs()

def run_cmd(self, cmd, background=False, envs=None):
ret = exec_cmd(cmd, background, envs)
if background:
self.sub_proc_list.append(ret)
else:
return ret

def cleanup_sub_procs(self):
for proc in self.sub_proc_list:
proc.kill()
proc.wait()
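`ProcRunGuard` is hoisted to the top of the module (the old copy further down is deleted) so that `ConfigDumper` can accept one. It tracks every subprocess launched via `run_cmd(..., background=True)` and kills them when the guard is destroyed. A minimal usage sketch (the `sleep` command stands in for a real server binary):

```python
# Illustrative usage of the guard above.
guard = ProcRunGuard()
guard.run_cmd("sleep 600", background=True)  # tracked in sub_proc_list
result = guard.run_cmd("echo hello")         # foreground: returns exec_cmd's result
guard.cleanup_sub_procs()                    # kill + wait every tracked process
```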


class ModelBuilder:
def __init__(self, name, desc, graph_def: GraphDef):
self.name = name
@@ -361,15 +383,17 @@ class ServerBaseConifg:
brpc_builtin_service_port: int
channel_protocol: str
model_id: str
http_feature_source_port: int = None


@dataclass
class PartyConfig(ServerBaseConifg):
feature_mapping: Dict[str, str]
graph_def: GraphDef
feature_mapping: Dict[str, str] = None
graph_def: GraphDef = None
query_datas: List[str] = None
query_context: str = None
csv_dict: Dict[str, Any] = None
csv_dict: Dict[str, Any] = None # must contain id column
use_http_feature_source: bool = False


class ConfigDumper:
@@ -380,13 +404,15 @@ def __init__(
serving_config_filename: str,
tar_name: str,
service_id: str,
proc_guard: ProcRunGuard = None,
):
self.service_id = service_id
self.party_configs = party_configs
self.parties = []
self.log_config = log_config_filename
self.serving_config = serving_config_filename
self.tar_name = tar_name
self.proc_guard = proc_guard
for config in self.party_configs:
self.parties.append({"id": config.id, "address": config.cluster_ip})

@@ -424,6 +450,53 @@ def make_csv_config(self, data_dict: Dict[str, List[Any]], path: str):
def _dump_serving_config(
self, path: str, config: PartyConfig, model_name: str, model_sha256: str
):
def dump_csv(path: str, data_dict: Dict[str, List[Any]]) -> str:
filename = "feature_source.csv"
file_path = os.path.join(path, filename)
with open(file_path, "w") as ofile:
writer = csv.DictWriter(ofile, fieldnames=list(data_dict.keys()))
writer.writeheader()
rows = []
for key, value in data_dict.items():
if len(rows) == 0:
rows = [{} for _ in value]
assert len(value) == len(
rows
), f"row count {len(value)} of {key} in data_dict is diff with {len(rows)}."
for i in range(len(value)):
rows[i][key] = value[i]
print("CSV Rows: ", rows)
for row in rows:
writer.writerow(row)
return file_path

def start_simple_feature_service(port: int, csv_path: str):
self.proc_guard.run_cmd(
f"./bazel-bin/secretflow_serving/tools/simple_feature_service/simple_feature_service --port={port} --csv_filename={csv_path} --csv_id_column_name=id ",
background=True,
)

feature_source_config = {"mockOpts": {"type": "MDT_RANDOM"}}

if config.csv_dict:
assert 'id' in config.csv_dict, 'csv_dict should contain id column'
csv_path = dump_csv(path, config.csv_dict)

if config.use_http_feature_source:
assert (
config.csv_dict is not None
), "http_feature_source require a csv file, so csv_dict must be set"
start_simple_feature_service(config.http_feature_source_port, csv_path)
feature_source_config = {
"http_opts": {
"endpoint": f"127.0.0.1:{config.http_feature_source_port}/BatchFeatureService/BatchFetchFeature"
}
}
elif config.csv_dict:
feature_source_config = {
"csv_opts": {"file_path": csv_path, "id_name": "id"}
}

config_dict = {
"id": self.service_id,
"serverConf": {
@@ -446,11 +519,7 @@ def _dump_serving_config(
"parties": self.parties,
"channel_desc": {"protocol": config.channel_protocol},
},
"featureSourceConf": (
self.make_csv_config(config.csv_dict, path)
if config.csv_dict
else {"mockOpts": {"type": "MDT_RANDOM"}}
),
"featureSourceConf": feature_source_config,
}
dump_json(config_dict, os.path.join(path, self.serving_config))

@@ -466,26 +535,6 @@ def dump(self, path="."):
self._dump_serving_config(config_path, config, model_name, model_sha256)


class ProcRunGuard:
def __init__(self):
self.sub_proc_list = []

def __del__(self):
self.cleanup_sub_procs()

def run_cmd(self, cmd, background=False, envs=None):
ret = exec_cmd(cmd, background, envs)
if background:
self.sub_proc_list.append(ret)
else:
return ret

def cleanup_sub_procs(self):
for proc in self.sub_proc_list:
proc.kill()
proc.wait()


# for every testcase, there should be a TestConfig instance
class TestConfig(ProcRunGuard):
def __init__(
@@ -525,6 +574,7 @@ def dump_config(self):
self.serving_config_name,
self.tar_name,
self.service_spec_id,
self,
).dump(self.model_path)

def get_server_start_args(self):
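`dump_config` now hands `self` to `ConfigDumper` as the `proc_guard`; because `TestConfig` inherits from `ProcRunGuard`, the `simple_feature_service` started while dumping configs is registered on the test case and torn down alongside its other subprocesses. A stripped-down sketch of that ownership chain (the `Mini*` names are hypothetical):

```python
# Hypothetical minimal version of the relationship in test_common.py.
class MiniDumper:
    def __init__(self, proc_guard):
        self.proc_guard = proc_guard

    def dump(self):
        # background helpers started here are tracked by the caller
        self.proc_guard.run_cmd("sleep 600", background=True)

class MiniTestConfig(ProcRunGuard):
    def dump_config(self):
        MiniDumper(proc_guard=self).dump()  # self owns the subprocess list
```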
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,13 @@

> please add your unreleased change here.
- [Feature] `MERGE_Y` operator supports `EXP_TAYLOR` link function

## 20240524 - 0.3.1b0

- [Bugfix] Fix missing base score in tree predict
- [Bugfix] Fix HTTP adapter error message formatting failure

## 20240423 - 0.3.0b0

- [Feature] Add Trace function
22 changes: 3 additions & 19 deletions bazel/repositories.bzl
@@ -46,18 +46,17 @@ def sf_serving_deps():
_kuscia()
_yacl()

_com_github_madler_zlib()
_com_github_brpc_brpc()

def _yacl():
maybe(
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.4b1.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b1.tar.gz",
],
strip_prefix = "yacl-0.4.4b1",
sha256 = "846f40f5e00cfb77427016623b9cb215e5150384a3cc7c663642d995ee7afec9",
strip_prefix = "yacl-0.4.5b1",
sha256 = "28064053b9add0db8e1e8e648421a0579f1d3e7ee8a4bbd7bd5959cb59598088",
)

def _kuscia():
@@ -358,21 +357,6 @@ def _rules_proto_grpc():

# serving not use brpc x-bd-xxx trace header, so drop the patch of yacl
# add for brpc compile
def _com_github_madler_zlib():
maybe(
http_archive,
name = "zlib",
build_file = "@yacl//bazel:zlib.BUILD",
strip_prefix = "zlib-1.3.1",
sha256 = "17e88863f3600672ab49182f217281b6fc4d3c762bde361935e436a95214d05c",
type = ".tar.gz",
patch_args = ["-p1"],
patches = ["@yacl//bazel:patches/zlib.patch"],
urls = [
"https://github.com/madler/zlib/archive/refs/tags/v1.3.1.tar.gz",
],
)

def _com_github_brpc_brpc():
maybe(
http_archive,