diff --git a/.github/workflows/pull-request-check.yaml b/.github/workflows/pull-request-check.yaml index 5c39ae94..e4e6bb4d 100644 --- a/.github/workflows/pull-request-check.yaml +++ b/.github/workflows/pull-request-check.yaml @@ -33,7 +33,7 @@ jobs: be+fe cdk others - subjectPattern: ^[\s\w-\.]{5,100}$ + subjectPattern: ^[\s\w-\./]{5,100}$ subjectPatternError: | The subject "{subject}" found in the pull request title "{title}" didn't match the configured pattern. Please ensure that the subject diff --git a/.semgrepignore b/.semgrepignore new file mode 100644 index 00000000..01205580 --- /dev/null +++ b/.semgrepignore @@ -0,0 +1,19 @@ +# Ignore all files under the tests directory +tests/ + +# Ignore all .json files +*.json + +# Ignore specific files +.gitignore +deployment/cdk-solution-helper/index.js +deployment/cdk-solution-helper/index.js +deployment/cdk-solution-helper/index.js +deployment/helper.py +source/portal/config/env.js +source/portal/config/modules.js +source/portal/config/modules.js +source/portal/config/modules.js +source/portal/config/paths.js +source/portal/nginx-config/start_nginx.sh +.github/* diff --git a/buildspec.yml b/buildspec.yml index 4bbff369..32385750 100755 --- a/buildspec.yml +++ b/buildspec.yml @@ -23,7 +23,7 @@ phases: - chmod +x ./run-all-tests.sh && ./run-all-tests.sh - echo "Installing dependencies and executing unit tests completed `date`" - export BSS_IMAGE_ASSET_REPOSITORY_NAME='aws-sensitive-data-protection' - - export BUILD_VERSION=1.1.1-${CODEBUILD_RESOLVED_SOURCE_VERSION:0:7} + - export BUILD_VERSION=1.1.2-${CODEBUILD_RESOLVED_SOURCE_VERSION:0:7} - export CN_ASSETS='cn/' - |- set -euxo pipefail diff --git a/deployment/build-s3-dist.sh b/deployment/build-s3-dist.sh index 9970619a..9334942d 100755 --- a/deployment/build-s3-dist.sh +++ b/deployment/build-s3-dist.sh @@ -42,6 +42,8 @@ title "cdk synth" run cd ${SRC_PATH} # Replace before building +sed -i "s|DEBUG|INFO|"g api/logging.conf + sed -i "s|@TEMPLATE_SOLUTION_VERSION@|$SOLUTION_VERSION|"g lib/admin/database/*/*.sql sed -i "s|@TEMPLATE_SOLUTION_VERSION@|$SOLUTION_VERSION|"g lib/agent/DiscoveryJob.json diff --git a/source/.viperlightignore b/source/.viperlightignore index 83156e5d..91dcc3b3 100755 --- a/source/.viperlightignore +++ b/source/.viperlightignore @@ -18,6 +18,10 @@ constructs/api/pytest/test_data_source.py constructs/api/pytest/test_labels.py constructs/api/pytest/test_query.py constructs/lib/common/solution-info.ts:30 +constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-cn.xlsx +constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-en.xlsx +constructs/config/batch_create/identifier/template/batch_create_identifier-cn.xlsx +constructs/config/batch_create/identifier/template/batch_create_identifier-en.xlsx [python-pipoutdated] pip=v21.1.2 diff --git a/source/constructs/api/README.md b/source/constructs/api/README.md index 1741f02a..d8dff21a 100644 --- a/source/constructs/api/README.md +++ b/source/constructs/api/README.md @@ -23,16 +23,31 @@ As shown in the following format. Please follow this configuration. Select 'Othe ```json {"username":"db username","password":"db password","engine":"mysql","host":"127.0.0.1","port":6306} ``` -### 4. Configuration development mode -The following command is configured as development mode to bypass authentication. +### 4. 
Configure development parameters ```shell +export AdminBucketName="Your admin bucket name" +# The two subnets where the API lambda is located +export SubnetIds="subnet-xxxxxx,subnet-xxxxxx" +# Development mode that bypasses authentication export mode=dev ``` -### 5. Starting web services locally +### 5. Run as API role +First, use `aws configure --profile cn` to configure user credentials. This user needs the `sts:AssumeRole` permission. +Second, modify the trust relationship of the SDPS API role in the Admin account and add the user from the first step as a principal. +Third, modify the `.aws/config` file and configure the default profile as follows +``` +[default] +region = cn-northwest-1 +source_profile = cn +role_arn = arn:aws-cn:iam::{AdminAccountId}:role/SDPSAPIRole-cn-northwest-1 +output = json +``` +Finally, validate using `aws sts get-caller-identity`. If the returned content contains `arn:aws-cn:sts::{AdminAccountId}:assumed-role/SDPSAPIRole-{Region}`, the configuration is correct. +### 6. Starting web services locally ```shell uvicorn main:app --reload ``` -### 6. View API +### 7. View API http://127.0.0.1:8000/docs ## File Naming diff --git a/source/constructs/api/catalog/crud.py b/source/constructs/api/catalog/crud.py index bbf55c67..c192294b 100644 --- a/source/constructs/api/catalog/crud.py +++ b/source/constructs/api/catalog/crud.py @@ -669,6 +669,16 @@ def delete_catalog_table_level_classification_by_database_region(database: str, ).delete() session.commit() +def delete_catalog_table_level_classification_by_database_region_batch(database: str, region: str, type: str): + session = get_session() + session.query(models.CatalogTableLevelClassification).filter( + models.CatalogTableLevelClassification.database_name == database, + models.CatalogTableLevelClassification.database_type == type + ).filter( + models.CatalogTableLevelClassification.region == region + ).delete(synchronize_session=False) + session.commit() + def delete_catalog_table_level_classification_by_database(database: str, region: str, type: str): session = get_session() @@ -698,6 +708,16 @@ def delete_catalog_database_level_classification_by_database_region(database: st ).delete() session.commit() +def delete_catalog_database_level_classification_by_database_region_batch(database: str, region: str, type: str): + session = get_session() + session.query(models.CatalogDatabaseLevelClassification).filter( + models.CatalogDatabaseLevelClassification.database_name == database, + models.CatalogDatabaseLevelClassification.database_type == type + ).filter( + models.CatalogDatabaseLevelClassification.region == region + ).delete(synchronize_session=False) + session.commit() + def delete_catalog_column_level_classification_by_database_region(database: str, region: str, type: str): session = get_session() @@ -709,6 +729,16 @@ def delete_catalog_column_level_classification_by_database_region(database: str, ).delete() session.commit() +def delete_catalog_column_level_classification_by_database_region_batch(database: str, region: str, type: str): + session = get_session() + session.query(models.CatalogColumnLevelClassification).filter( + models.CatalogColumnLevelClassification.database_name == database, + models.CatalogColumnLevelClassification.database_type == type + ).filter( + models.CatalogColumnLevelClassification.region == region + ).delete(synchronize_session=False) + session.commit() + def delete_catalog_column_level_classification_by_database(database: str, region: 
str, type: str): session = get_session() @@ -1056,11 +1086,14 @@ def update_catalog_table_labels( def get_export_catalog_data(): - return get_session().query(models.CatalogColumnLevelClassification.account_id, + return get_session().query(models.CatalogColumnLevelClassification.database_type, + models.CatalogColumnLevelClassification.account_id, models.CatalogColumnLevelClassification.region, - models.CatalogColumnLevelClassification.database_type, models.CatalogColumnLevelClassification.database_name, + models.CatalogDatabaseLevelClassification.description, + models.CatalogDatabaseLevelClassification.url, models.CatalogColumnLevelClassification.table_name, + models.CatalogTableLevelClassification.storage_location, models.CatalogColumnLevelClassification.column_name, models.CatalogColumnLevelClassification.column_path, models.CatalogColumnLevelClassification.identifier, diff --git a/source/constructs/api/catalog/main.py b/source/constructs/api/catalog/main.py index 4e980182..03e91a4a 100644 --- a/source/constructs/api/catalog/main.py +++ b/source/constructs/api/catalog/main.py @@ -395,7 +395,6 @@ def agg_catalog_summary_by_privacy(database_type: str): @router.get("/dashboard/agg-catalog-top-n", response_model=BaseResponse) @inject_session def agg_catalog_top_n(database_type: str, top_n: int): - return service_dashboard.agg_catalog_data_source_top_n(database_type, top_n) diff --git a/source/constructs/api/catalog/service.py b/source/constructs/api/catalog/service.py index 296f4246..2acabf15 100644 --- a/source/constructs/api/catalog/service.py +++ b/source/constructs/api/catalog/service.py @@ -222,6 +222,7 @@ def sync_crawler_result( database_type: str, database_name: str, ): + database_name = database_name.strip() logger.info(f"start params {account_id} {region} {database_type} {database_name}") start_time = time.time() rds_engine_type = const.NA @@ -243,7 +244,7 @@ def sync_crawler_result( ) if jdbc_database: jdbc_engine_type = jdbc_database.jdbc_connection_url.split(':')[1] - + if need_change_account_id(database_type): glue_client = get_boto3_client(admin_account_id, admin_region, "glue") else: @@ -294,7 +295,7 @@ def sync_crawler_result( table_create_list = [] table_update_list = [] for table in tables_response["TableList"]: - table_name = table["Name"].strip() + table_name = table["Location"].strip() # If the file is end of .csv or .json, but the content of the file is not csv/json # glue can't crawl them correctly # So there is no sizeKey in Parameters, we set the default value is 0 @@ -430,7 +431,7 @@ def sync_crawler_result( if delete_glue_table_names: logger.info("batch delete glue tables" + json.dumps(delete_glue_table_names)) glue_client.batch_delete_table(DatabaseName=glue_database_name, - TablesToDelete=delete_glue_table_names) + TablesToDelete=delete_glue_table_names) except Exception as err: logger.exception("batch delete glue tables error" + str(err)) if logger.isEnabledFor(logging.DEBUG): @@ -510,11 +511,17 @@ def sync_crawler_result( storage_location = rds_engine_type elif database_type == DatabaseType.GLUE.value: storage_location = const.NA + connection_info = glue_client.get_connection(Name=f"{const.SOLUTION_NAME}-{database_type}-{database_name}")['Connection'] if database_type.startswith(DatabaseType.JDBC.value) else {} + description = connection_info.get('Description', '') + url = connection_info.get('ConnectionProperties', {}).get('JDBC_CONNECTION_URL', '') + logger.info(f"connection_info description url!!!!!!!!!!:{description},{url}") catalog_database_dict = { 
"account_id": account_id, "region": region, "database_type": database_type, "database_name": database_name, + "description": description, + "url": url, "object_count": database_object_count, # not error , logic change when 1.1.0 "size_key": database_size, @@ -532,6 +539,7 @@ def sync_crawler_result( } original_database = crud.get_catalog_database_level_classification_by_name(account_id, region, database_type, database_name) + logger.info(f"original_database is {original_database}") if original_database == None: crud.create_catalog_database_level_classification(catalog_database_dict) else: @@ -715,7 +723,7 @@ def __query_job_result_by_athena( else: time.sleep(1) athena_result_list = [] - next_token = '' + next_token = None while True: if next_token: result = client.get_query_results(QueryExecutionId=query_id, NextToken=next_token) @@ -723,9 +731,8 @@ def __query_job_result_by_athena( result = client.get_query_results(QueryExecutionId=query_id) athena_result_list.append(result) # logger.info(result) - if "NextToken" in result and result['NextToken']: - next_token = result['NextToken'] - else: + next_token = response.get('NextToken') + if not next_token: break # result = client.get_query_results(QueryExecutionId=query_id, NextToken='') @@ -761,7 +768,6 @@ def __convert_identifiers_to_dict(identifiers: str): result_dict[i[0]] = i[1] return result_dict - def sync_job_detection_result( account_id: str, region: str, @@ -848,7 +854,7 @@ def sync_job_detection_result( not overwrite and catalog_column.manual_tag != const.MANUAL)): column_dict = { "id": catalog_column.id, - "identifier": json.dumps(identifier_dict), + "identifier": json.dumps(identifier_dict, ensure_ascii=False), "column_value_example": column_sample_data, "column_path": column_path, "privacy": column_privacy, @@ -1192,11 +1198,12 @@ def delete_catalog_by_account_region(account_id: str, region: str): ) try: crud.delete_catalog_column_level_classification_by_account_region(account_id, region) - except Exception: - raise BizException( - MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_code(), - MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_msg(), - ) + except Exception as e: + logger.info(f"{str(e)}") + # raise BizException( + # MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_code(), + # MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_msg(), + # ) return True @@ -1217,11 +1224,37 @@ def delete_catalog_by_database_region(database: str, region: str, type: str): ) try: crud.delete_catalog_column_level_classification_by_database_region(database, region, type) + except Exception as e: + logger.info(f"{str(e)}") + # raise BizException( + # MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_code(), + # MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_msg(), + # ) + return True + +def delete_catalog_by_database_region_batch(database: str, region: str, type: str): + try: + crud.delete_catalog_database_level_classification_by_database_region_batch(database, region, type) except Exception: raise BizException( - MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_code(), - MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_msg(), + MessageEnum.CATALOG_DATABASE_DELETE_FAILED.get_code(), + MessageEnum.CATALOG_DATABASE_DELETE_FAILED.get_msg(), + ) + try: + crud.delete_catalog_table_level_classification_by_database_region_batch(database, region, type) + except Exception: + raise BizException( + MessageEnum.CATALOG_TABLE_DELETE_FAILED.get_code(), + MessageEnum.CATALOG_TABLE_DELETE_FAILED.get_msg(), ) + try: + 
crud.delete_catalog_column_level_classification_by_database_region_batch(database, region, type) + except Exception as e: + logger.info(f"{str(e)}") + # raise BizException( + # MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_code(), + # MessageEnum.CATALOG_COLUMN_DELETE_FAILED.get_msg(), + # ) return True @@ -1386,33 +1419,49 @@ def filter_records(all_items: list, all_labels_dict: dict, sensitive_flag: str): jdbc_records = [] for row in all_items: row_result = [cell for cell in row] - if sensitive_flag != 'all' and "N/A" in row_result[7]: + if sensitive_flag != 'all' and "N/A" in row_result[10]: continue - if row_result[9]: - row_result[9] = ",".join(gen_labels(all_labels_dict, row_result[9])) - if row_result[10]: - row_result[10] = ",".join(gen_labels(all_labels_dict, row_result[10])) - catalog_type = row_result[2] + if row_result[12]: + row_result[12] = ",".join(gen_labels(all_labels_dict, row_result[12])) + if row_result[13]: + row_result[13] = ",".join(gen_labels(all_labels_dict, row_result[13])) + catalog_type = row_result[0] if catalog_type == DatabaseType.S3.value: + del row_result[0] + del row_result[3] + del row_result[3] + del row_result[3] + del row_result[4] + del row_result[4] s3_records.append([row_result]) elif catalog_type == DatabaseType.S3_UNSTRUCTURED.value: + del row_result[0] + del row_result[3] + del row_result[3] + del row_result[3] + del row_result[4] + del row_result[4] s3_unstructured_records.append([row_result]) elif catalog_type == DatabaseType.RDS.value: - del row_result[6] + del row_result[0] + del row_result[3] + del row_result[3] + del row_result[4] + del row_result[5] rds_records.append([row_result]) elif catalog_type == DatabaseType.GLUE.value: - del row_result[6] + del row_result[0] + del row_result[3] + del row_result[3] + del row_result[4] + del row_result[5] glue_records.append([row_result]) elif catalog_type.startswith(DatabaseType.JDBC.value): - del row_result[6] + del row_result[7] + del row_result[8] jdbc_records.append([row_result]) else: pass - # return {const.EXPORT_S3_MARK_STR: {const.EXPORT_S3_SHEET_TITLE: s3_records, - # const.EXPORT_S3_UNSTRUCTURED_SHEET_TITLE: s3_unstructured_records}, - # const.EXPORT_RDS_MARK_STR: {const.EXPORT_RDS_SHEET_TITLE: rds_records}, - # const.EXPORT_GLUE_MARK_STR: {const.EXPORT_RDS_SHEET_TITLE: glue_records}, - # const.EXPORT_JDBC_MARK_STR: {const.EXPORT_RDS_SHEET_TITLE: jdbc_records}} return {const.EXPORT_S3_MARK_STR: s3_records, const.EXPORT_S3_UNSTRUCTURED_MARK_STR: s3_unstructured_records, const.EXPORT_RDS_MARK_STR: rds_records, @@ -1434,52 +1483,16 @@ def gen_zip_file(header, record, tmp_filename, type): batches = int(len(v) / const.EXPORT_XLSX_MAX_LINES) if batches < 1: __gen_xlsx_file(k, header.get(k), v, 0, zipf) - # wb = Workbook() - # ws1 = wb.active - # ws1.title = k - # ws1.append(header.get(k)) - # for row_index in range(0, len(v)): - # ws1.append([__get_cell_value(cell) for cell in v[row_index][0]]) - # file_name = f"{tmp_folder}/{k}.xlsx" - # wb.save(file_name) - # zipf.write(file_name, os.path.abspath(file_name)) - # os.remove(file_name) else: for i in range(0, batches + 1): __gen_xlsx_file(f"{k}_{i+1}", header.get(k), v, const.EXPORT_XLSX_MAX_LINES * i, zipf) - # wb = Workbook() - # ws1 = wb.active - # ws1.title = k - # ws1.append(header.get(k)) - # for row_index in range(const.EXPORT_XLSX_MAX_LINES * i, min(const.EXPORT_XLSX_MAX_LINES * (i + 1), len(v))): - # ws1.append([__get_cell_value(cell) for cell in v[row_index][0]]) - # file_name = f"{tmp_folder}/{k}_{i+1}.xlsx" - # wb.save(file_name) - # 
zipf.write(file_name, os.path.basename(file_name)) - # os.remove(file_name) else: batches = int(len(v) / const.EXPORT_CSV_MAX_LINES) if batches < 1: __gen_csv_file(k, header.get(k), v, 0, zipf) - # file_name = f"{tmp_folder}/{k}.csv" - # with open(file_name, 'w', encoding="utf-8-sig", newline='') as csv_file: - # csv_writer = csv.writer(csv_file) - # csv_writer.writerow(header.get(k)) - # for record in v: - # csv_writer.writerow([__get_cell_value(cell) for cell in record[0]]) - # zipf.write(file_name, os.path.abspath(file_name)) - # os.remove(file_name) else: for i in range(0, batches + 1): __gen_csv_file(f"{k}_{i+1}", header.get(k), v, const.EXPORT_CSV_MAX_LINES * i, zipf) - # file_name = f"{tmp_folder}/{k}_{i+1}.csv" - # with open(file_name, 'w', encoding="utf-8-sig", newline='') as csv_file: - # csv_writer = csv.writer(csv_file) - # csv_writer.writerow(header.get(k)) - # for record in v[const.EXPORT_CSV_MAX_LINES * i: min(const.EXPORT_CSV_MAX_LINES * (i + 1), len(v))]: - # csv_writer.writerow([__get_cell_value(cell) for cell in record[0]]) - # zipf.write(file_name, os.path.abspath(file_name)) - # os.remove(file_name) def __get_cell_value(cell: dict): if isinstance(cell, datetime): diff --git a/source/constructs/api/catalog/service_dashboard.py b/source/constructs/api/catalog/service_dashboard.py index bae51eba..6fe3d54e 100644 --- a/source/constructs/api/catalog/service_dashboard.py +++ b/source/constructs/api/catalog/service_dashboard.py @@ -8,17 +8,14 @@ Provider ) from common.constant import const -import logging from common.exception_handler import BizException import heapq from common.query_condition import QueryCondition - - -logger = logging.getLogger("api") +from common.reference_parameter import logger def agg_data_source_summary(provider_id): - if provider_id == Provider.AWS_CLOUD.value: + if provider_id == str(Provider.AWS_CLOUD.value): account_set, region_set = count_aws_account_region() # Get data source total region. 
else: @@ -28,17 +25,14 @@ def agg_data_source_summary(provider_id): return result_dict def count_aws_account_region(): - s3_account_region = data_source_crud.get_source_s3_account_region() - rds_account_region = data_source_crud.get_source_rds_account_region() + # s3_account_region = data_source_crud.get_source_s3_account_region() + # rds_account_region = data_source_crud.get_source_rds_account_region() + aws_account_region = data_source_crud.get_source_aws_account_region() account_set = set() region_set = set() - for d in s3_account_region: - account_set.add(d['aws_account']) - region_set.add(d['region']) - - for d in rds_account_region: - account_set.add(d['aws_account']) + for d in aws_account_region: + account_set.add(d['account_id']) region_set.add(d['region']) return account_set, region_set @@ -158,8 +152,8 @@ def agg_catalog_data_source_top_n(database_type: str, top_n: int): for table in table_rows: if table.identifiers == const.NA: continue - data_source_full_name = table.account_id + table.region + table.database_type + table.database_name - + type = 's3' if table.database_type == 'unstructured' else table.database_type + data_source_full_name = table.account_id + table.region + type + table.database_name table_identifiers = table.identifiers.split("|") for identifier in table_identifiers: if identifier == const.NA or identifier == "": @@ -175,8 +169,6 @@ def agg_catalog_data_source_top_n(database_type: str, top_n: int): if database.account_id not in account_dict: account_dict[database.account_id]=set() account_dict[database.account_id].add(data_source_full_name) - - result_dict['account_top_n'] = __get_top_n_count(account_dict, top_n) logger.debug(identifier_dict.keys()) @@ -290,15 +282,28 @@ def get_database_by_identifier_paginate_s3(condition: QueryCondition): identifier = con.values[0] table_list = crud.get_s3_catalog_table_level_classification_by_identifier(identifier) for table in table_list: - database_full_name = table.account_id + "|" + table.region + "|" + table.database_type + "|" + table.database_name + type = 's3' if table.database_type == 'unstructured' else table.database_type + database_full_name = table.account_id + "|" + table.region + "|" + type + "|" + table.database_name database_set.add(database_full_name) database_list = sorted(list(database_set)) for database_full_name in database_list: database_info = database_full_name.split("|") + db_type = database_info[2] result_db = crud.get_catalog_database_level_classification_by_name(database_info[0], - database_info[1], - database_info[2], - database_info[3]) + database_info[1], + database_info[2], + database_info[3]) + if db_type == "s3": + result_db_unstructured = crud.get_catalog_database_level_classification_by_name(database_info[0], + database_info[1], + 'unstructured', + database_info[3]) + if result_db_unstructured: + if result_db: + result_db.size_key += result_db_unstructured.size_key + result_db.object_count += result_db_unstructured.object_count + else: + result_db = result_db_unstructured if result_db: result_list.append(result_db) if condition.size >= len(result_list): diff --git a/source/constructs/api/common/abilities.py b/source/constructs/api/common/abilities.py index a74c6c36..b82eaed8 100644 --- a/source/constructs/api/common/abilities.py +++ b/source/constructs/api/common/abilities.py @@ -1,6 +1,9 @@ from common.enum import (Provider, ProviderName, DatabaseType) +from common.reference_parameter import logger, admin_account_id +from common.constant import const +from openpyxl.styles import Font, 
PatternFill def convert_database_type_2_provider(database_type: str) -> int: @@ -43,6 +46,15 @@ def need_change_account_id(database_type: str) -> bool: return True return False + +def is_run_in_admin_vpc(database_type: str, account_id: str = None) -> bool: + if database_type == DatabaseType.JDBC_AWS.value: + return account_id == admin_account_id + elif database_type.startswith(DatabaseType.JDBC.value): + return True + return False + + def query_all_vpc(ec2_client): vpcs = [] response = ec2_client.describe_vpcs() @@ -51,3 +63,15 @@ def query_all_vpc(ec2_client): response = ec2_client.describe_vpcs(NextToken=response['NextToken']) vpcs.append(response['Vpcs']) return vpcs[0] + +def insert_error_msg_2_cells(sheet, row_index, msg, res_column_index): + if msg in [const.EXISTED_MSG, const.IDENTIFIER_EXISTED_MSG]: + sheet.cell(row=row_index + 1, column=res_column_index, value="WARNING") + sheet.cell(row=row_index + 1, column=res_column_index).font = Font(color='563112', bold=True) + else: + sheet.cell(row=row_index + 1, column=res_column_index, value="FAILED") + sheet.cell(row=row_index + 1, column=res_column_index).font = Font(color='FF0000', bold=True) + sheet.cell(row=row_index + 1, column=res_column_index + 1, value=msg) + +def insert_success_2_cells(sheet, row_index, res_column_index): + sheet.cell(row=row_index + 1, column=res_column_index, value="SUCCESSED") \ No newline at end of file diff --git a/source/constructs/api/common/constant.py b/source/constructs/api/common/constant.py index 0a549ab0..23631dea 100644 --- a/source/constructs/api/common/constant.py +++ b/source/constructs/api/common/constant.py @@ -64,15 +64,15 @@ def __setattr__(self, name, value): const.SYSTEM = 'system' const.SAMPLE_LIMIT = 1000 const.LAMBDA_MAX_RUNTIME = 900 -const.EXPORT_FILE_S3_COLUMNS = ["account_id", "region", "type", "s3_bucket", "folder_name", "column_name", "column_path", "identifiers", "sample_data", +const.EXPORT_FILE_S3_COLUMNS = ["account_id", "region", "bucket_name", "location", "identifiers", "sample_data", "bucket_catalog_label", "folder_catalog_label", "comment", "last_updated_at", "last_updated_by"] -const.EXPORT_FILE_S3_UNSTRUCTURED_COLUMNS = ["account_id", "region", "type", "s3_bucket", "folder_name", "sample_object_name", "s3_location", +const.EXPORT_FILE_S3_UNSTRUCTURED_COLUMNS = ["account_id", "region", "bucket_name", "location", "identifiers", "sample_data", "bucket_catalog_label", "folder_catalog_label", "comment", "last_updated_at", "last_updated_by"] -const.EXPORT_FILE_RDS_COLUMNS = ["account_id", "region", "type", "rds_instance_id", "table_name", "column_name", "identifiers", "sample_data", +const.EXPORT_FILE_RDS_COLUMNS = ["account_id", "region", "instance_name", "table_name", "column_name", "identifiers", "sample_data", "instance_catalog_label", "table_catalog_label", "comment", "last_updated_at", "last_updated_by"] -const.EXPORT_FILE_GLUE_COLUMNS = ["account_id", "region", "type", "glue_database", "table_name", "column_name", "identifiers", "sample_data", +const.EXPORT_FILE_GLUE_COLUMNS = ["account_id", "region", "database_name", "table_name", "column_name", "identifiers", "sample_data", "instance_catalog_label", "table_catalog_label", "comment", "last_updated_at", "last_updated_by"] -const.EXPORT_FILE_JDBC_COLUMNS = ["account_id", "region", "type", "jdbc_connection", "table_name", "column_name", "identifiers", "sample_data", +const.EXPORT_FILE_JDBC_COLUMNS = ["type", "account_id", "region", "instance_name", "description", "jdbc_url", "table_name", "column_name", "identifiers", 
"sample_data", "instance_catalog_label", "table_catalog_label", "comment", "last_updated_at", "last_updated_by"] const.EXPORT_XLSX_MAX_LINES = 30000 const.EXPORT_CSV_MAX_LINES = 60000 @@ -96,6 +96,46 @@ def __setattr__(self, name, value): const.PUBLIC = 'Public' const.PRIVATE = 'Private' const.ZERO = 0 +const.BATCH_CREATE_LIMIT = 1000 +const.BATCH_SHEET = "OriginTemplate" +const.REGION_CORD = "region_cord" +const.REGION_ALIAS = "region_alias" + +const.CONNECTION_DESC_MAX_LEN = 2048 +const.BATCH_CREATE_TEMPLATE_PATH_CN = 'batch_create/datasource/template/batch_create_jdbc_datasource-cn.xlsx' +const.BATCH_CREATE_TEMPLATE_PATH_EN = 'batch_create/datasource/template/batch_create_jdbc_datasource-en.xlsx' +const.BATCH_CREATE_IDENTIFIER_TEMPLATE_PATH_CN = 'batch_create/identifier/template/batch_create_identifier-cn.xlsx' +const.BATCH_CREATE_IDENTIFIER_TEMPLATE_PATH_EN = 'batch_create/identifier/template/batch_create_identifier-en.xlsx' +const.BATCH_CREATE_REPORT_PATH = 'batch_create/datasource/report' +const.BATCH_CREATE_IDENTIFIER_REPORT_PATH = 'batch_create/identifier/report' +const.EXISTED_MSG = 'JDBC connection with the same instance already exists' +const.IDENTIFIER_EXISTED_MSG = "A data identifier with the same name already exists" +const.DATASOURCE_REPORT = "batch_export/datasource" +const.IDENTIFY_REPORT = "batch_export/identify" + +const.CONFIG_CONCURRENT_RUN_JOB_NUMBER = 'ConcurrentRunJobNumber' +const.CONFIG_CONCURRENT_RUN_JOB_NUMBER_DEFAULT_VALUE = 50 +const.CONFIG_SUB_JOB_NUMBER_S3 = 'SubJobNumberS3' +const.CONFIG_SUB_JOB_NUMBER_S3_DEFAULT_VALUE = 10 +const.CONFIG_SUB_JOB_NUMBER_RDS = 'SubJobNumberRds' +const.CONFIG_SUB_JOB_NUMBER_RDS_DEFAULT_VALUE = 3 +const.CONTROLLER_ACTION = 'Action' +const.CONTROLLER_ACTION_SCHEDULE_JOB = 'ScheduleJob' +const.CONTROLLER_ACTION_CHECK_RUNNING_RUN_DATABASES = 'CheckRunningRunDatabases' +const.CONTROLLER_ACTION_CHECK_PENDING_RUN_DATABASES = 'CheckPendingRunDatabases' +const.CONTROLLER_ACTION_REFRESH_ACCOUNT = 'RefreshAccount' + +const.EXPORT_DS_HEADER_S3 = ["account_id", "region", "bucket_name", "crawler_status", "last_updated_at", "last_updated_by"] +const.EXPORT_DS_HEADER_RDS = ["account_id", "region", "instance_name", "engine_type", "location", "crawler_status", "last_updated_at", "last_updated_by"] +const.EXPORT_DS_HEADER_GLUE = ["account_id", "region", "database_name", "description", "location", "crawler_status", "last_updated_at", "last_updated_by"] +const.EXPORT_DS_HEADER_JDBC = ["type", "account_id", "region", "instance_name", "description", "location", "crawler_status", "last_updated_at", "last_updated_by"] +const.S3_STR = "S3" +const.RDS_STR = "RDS" +const.GLUE_STR = "GLUE" +const.JDBC_STR = "JDBC" +# const.DATASOURCE_REPORT = "report/datasource" +# const.IDENTIFY_REPORT = "report/identify" +const.EXPORT_IDENTIFY_HEADER = ["Data identify name", "Description", "Rule", "header_keywords", "exclude_keywords", "max_distance", "min_occurrence", "Identify Category", "Identify label"] const.UNSTRUCTURED_FILES = { "document": ["doc", "docx", "pdf", "ppt", "pptx", "xls", "xlsx", "odp"], diff --git a/source/constructs/api/common/enum.py b/source/constructs/api/common/enum.py index 1cfd76fa..f34a6468 100644 --- a/source/constructs/api/common/enum.py +++ b/source/constructs/api/common/enum.py @@ -1,5 +1,7 @@ from enum import Enum, unique +from common.constant import const + # system 1000 ~ 1099 # user 1100 ~ 1199 @@ -21,7 +23,7 @@ class MessageEnum(Enum): # template TEMPLATE_NOT_EXISTS = {1401: "The classification template does not exist"} 
TEMPLATE_IDENTIFIER_NOT_EXISTS = {1402: "The data identifier does not exist"} - TEMPLATE_IDENTIFIER_EXISTS = {1403: "A data identifier with the same name already exists"} + TEMPLATE_IDENTIFIER_EXISTS = {1403: const.IDENTIFIER_EXISTED_MSG} TEMPLATE_IDENTIFIER_USED = {1404: "The data identifier is being used"} TEMPLATE_PROPS_USED = {1405: "The item is being used"} TEMPLATE_PROPS_EXISTS = {1406: "A category/regulation with the same name already exists"} @@ -99,7 +101,7 @@ class MessageEnum(Enum): SOURCE_JDBC_NO_CREDENTIAL = {1231: "No credential"} SOURCE_JDBC_NO_AUTH = {1232: "No authorization"} SOURCE_JDBC_DUPLICATE_AUTH = {1233: "Duplicate authorization"} - SOURCE_JDBC_ALREADY_EXISTS = {1234: "JDBC connection with the same instance already exists"} + SOURCE_JDBC_ALREADY_EXISTS = {1234: const.EXISTED_MSG} SOURCE_GLUE_DATABASE_EXISTS = {1235: "Glue database with the same name already exists"} SOURCE_GLUE_DATABASE_NO_INSTANCE = {1236: "Glue database does not exist"} SOURCE_SECURITY_GROUP_NOT_EXISTS = {1237: "Security for jdbc connection is not existed"} @@ -123,8 +125,11 @@ class MessageEnum(Enum): SOURCE_JDBC_ALREADY_IMPORTED = {1255: "JDBC connection with the same instance already be imported"} SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED = {1256: "JDBC list databases not supported."} SOURCE_JDBC_LIST_DATABASES_FAILED = {1257: "JDBC list databases failed."} - SOURCE_ACCOUNT_ID_ALREADY_EXISTS = {1256: "A duplicate account with the same name already exists. Please note that account names must be unique."} - + SOURCE_ACCOUNT_ID_ALREADY_EXISTS = {1258: "A duplicate account with the same name already exists. Please note that account names must be unique."} + SOURCE_BATCH_CREATE_FORMAT_ERR = {1259: "Invalid file type, please provide an Excel file (.xlsx)."} + SOURCE_BATCH_CREATE_LIMIT_ERR = {1260: "Batch operation limit exceeded, please ensure that a maximum of 100 data sources are created at a time."} + SOURCE_BATCH_SHEET_NOT_FOUND = {1261: "Sheet [OriginTemplate] not found in the Excel file"} + SOURCE_BATCH_SHEET_NO_CONTENT = {1262: "There is no relevant data in sheet [OriginTemplate], please add data according to the format."} # label LABEL_EXIST_FAILED = {1611: "Cannot create duplicated label"} @@ -164,6 +169,7 @@ class RunState(Enum): @unique class RunDatabaseState(Enum): READY = "Ready" + PENDING = "Pending" RUNNING = "Running" SUCCEEDED = "Succeeded" FAILED = "Failed" @@ -237,6 +243,7 @@ class CatalogModifier(Enum): class ConnectionState(Enum): PENDING = "PENDING" CRAWLING = "CRAWLING" + AUTHORIZED = "AUTHORIZED" ACTIVE = "ACTIVE" UNSUPPORTED = "UNSUPPORTED FILE TYPES" ERROR = "ERROR" @@ -289,8 +296,6 @@ class OperationType(Enum): NOT_CONTAIN = "!:" IN = "in" - - @unique class AutoSyncDataAction(Enum): DELETE_ACCOUNT = "DeleteAccount" @@ -309,7 +314,7 @@ class Provider(Enum): @unique class ProviderName(Enum): - AWS_CLOUD = 'AWS' + AWS_CLOUD = 'AWS CLOUD' TENCENT_CLOUD = 'TENCENT CLOUD' GOOGLE_CLOUD = 'GOOGLE CLOUD' JDBC_PROXY = 'JDBC PROXY' diff --git a/source/constructs/api/common/exception_handler.py b/source/constructs/api/common/exception_handler.py index 9902d5d7..5a1a2790 100644 --- a/source/constructs/api/common/exception_handler.py +++ b/source/constructs/api/common/exception_handler.py @@ -1,11 +1,9 @@ from fastapi import status, FastAPI, Request from fastapi.exceptions import RequestValidationError -import os from common.enum import MessageEnum from common.constant import const import logging from .response_wrapper import resp_err - import traceback logger = 
logging.getLogger(const.LOGGER_API) @@ -31,8 +29,6 @@ async def exception_handler(req: Request, exc: Exception): if isinstance(exc, BizException): return error_msg = traceback.format_exc() - if os.getenv(const.MODE) != const.MODE_DEV: - error_msg = error_msg.replace("\n", "\r") logger.error(error_msg) return resp_err(MessageEnum.BIZ_UNKNOWN_ERR.get_code(), MessageEnum.BIZ_UNKNOWN_ERR.get_msg()) @@ -46,3 +42,6 @@ def __init__(self, self.code = code self.message = message self.ref = ref + + def __msg__(self): + return self.message diff --git a/source/constructs/api/common/log_formatter.py b/source/constructs/api/common/log_formatter.py new file mode 100644 index 00000000..018e462b --- /dev/null +++ b/source/constructs/api/common/log_formatter.py @@ -0,0 +1,11 @@ +import os +import logging + +is_running_in_lambda = 'AWS_LAMBDA_FUNCTION_NAME' in os.environ + + +class CustomFormatter(logging.Formatter): + def format(self, record): + if is_running_in_lambda: + record.msg = str(record.msg).replace("\n", "\r") + return super().format(record) diff --git a/source/constructs/api/common/reference_parameter.py b/source/constructs/api/common/reference_parameter.py index 94758dd8..a8673dd3 100644 --- a/source/constructs/api/common/reference_parameter.py +++ b/source/constructs/api/common/reference_parameter.py @@ -11,3 +11,4 @@ partition = caller_identity['Arn'].split(':')[1] url_suffix = const.URL_SUFFIX_CN if partition == const.PARTITION_CN else '' public_account_id = const.PUBLIC_ACCOUNT_ID_CN if partition == const.PARTITION_CN else const.PUBLIC_ACCOUNT_ID_GLOBAL +admin_subnet_ids = os.getenv('SubnetIds', '').split(',') diff --git a/source/constructs/api/common/request_wrapper.py b/source/constructs/api/common/request_wrapper.py index b0cda2db..9373e2fd 100644 --- a/source/constructs/api/common/request_wrapper.py +++ b/source/constructs/api/common/request_wrapper.py @@ -2,11 +2,9 @@ from functools import wraps import time from pydantic import BaseModel -import os from typing import Optional from common.constant import const import logging - from db.database import gen_session, close_session from common.response_wrapper import resp_ok from fastapi_pagination.bases import RawParams @@ -26,15 +24,15 @@ def inject_session(func): @wraps(func) def wrapper(*args, **kwargs): start_time = time.time() - newline_character = "\n" if os.getenv(const.MODE) == const.MODE_DEV else "\r" # Parameters may contain sensitive information entered by users, such as database connection information, # so they will not be output in the production environment - logger.debug(f"START >>>{newline_character}METHOD: {func.__name__}{newline_character}PARAMS: {kwargs}") + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"START >>> METHOD: {func.__name__} PARAMS: {kwargs}") try: gen_session() result = func(*args, **kwargs) res = resp_ok(result) - logger.debug(f"END >>> USED:{round(time.time()-start_time)}ms") + logger.debug(f"END >>> USED:{round(time.time()-start_time)}s") return res finally: close_session() diff --git a/source/constructs/api/config/crud.py b/source/constructs/api/config/crud.py index a3872926..09873de7 100644 --- a/source/constructs/api/config/crud.py +++ b/source/constructs/api/config/crud.py @@ -17,3 +17,7 @@ def set_value(key: str, value: str): config_value=value) session.add(db_config) session.commit() + + +def list_config(): + return get_session().query(models.Config).all() diff --git a/source/constructs/api/config/main.py b/source/constructs/api/config/main.py new file mode 100644 index 00000000..211d56f6 
--- /dev/null +++ b/source/constructs/api/config/main.py @@ -0,0 +1,24 @@ +from fastapi import APIRouter +from . import service, schemas +from common.request_wrapper import inject_session +from common.response_wrapper import BaseResponse + +router = APIRouter(prefix="/config", tags=["config"]) + + +@router.get("", response_model=BaseResponse[list[schemas.ConfigBase]]) +@inject_session +def list_config(): + return service.list_config() + + +@router.post("") +@inject_session +def set_config(configs: list[schemas.ConfigBase]): + return service.set_configs(configs) + + +@router.get("/subnets", response_model=BaseResponse[list[schemas.SubnetInfo]]) +@inject_session +def list_subnets(): + return service.list_subnets() diff --git a/source/constructs/api/config/schemas.py b/source/constructs/api/config/schemas.py new file mode 100644 index 00000000..ce62f85e --- /dev/null +++ b/source/constructs/api/config/schemas.py @@ -0,0 +1,20 @@ +from typing import Optional +from pydantic import BaseModel +import db.models_config as models + + +class ConfigBase(BaseModel): + config_key: str + config_value: str + + class Meta: + orm_model = models.Config + + class Config: + orm_mode = True + + +class SubnetInfo(BaseModel): + subnet_id: str + name: Optional[str] + available_ip_address_count: int diff --git a/source/constructs/api/config/service.py b/source/constructs/api/config/service.py index 563247c9..4ab32bf0 100644 --- a/source/constructs/api/config/service.py +++ b/source/constructs/api/config/service.py @@ -1,9 +1,46 @@ -from . import crud +from . import crud,schemas +import boto3 +from common.reference_parameter import admin_subnet_ids def set_config(key: str, value: str): - return crud.set_value(key, value) + crud.set_value(key, value) + + +def get_config(key: str, default_value=None) -> str: + _value = crud.get_value(key) + if _value: + return _value + if default_value: + return default_value + return None + + +def list_config(): + return crud.list_config() + + +def set_configs(configs: list[schemas.ConfigBase]): + for config in configs: + set_config(config.config_key, config.config_value) + + +def list_subnets(): + ec2_client = boto3.client('ec2') + response = ec2_client.describe_subnets(SubnetIds=admin_subnet_ids) + subnet_infos = [] + for subnet in response['Subnets']: + subnet_info = schemas.SubnetInfo(subnet_id=subnet['SubnetId'], + name=__get_name(subnet['Tags']), + available_ip_address_count=subnet['AvailableIpAddressCount']) + subnet_infos.append(subnet_info) + return subnet_infos + + +def __get_name(tags: list) -> str: + for tag in tags: + if tag.get("Key") == "Name": + return tag.get("Value") + return None -def get_config(key: str) -> str: - return crud.get_value(key) diff --git a/source/constructs/api/data_source/crud.py b/source/constructs/api/data_source/crud.py index f99516ad..8ef30fac 100644 --- a/source/constructs/api/data_source/crud.py +++ b/source/constructs/api/data_source/crud.py @@ -384,7 +384,7 @@ def create_s3_connection(account: str, region: str, bucket: str, glue_connection s3_bucket_source.glue_database = glue_database_name s3_bucket_source.glue_crawler = crawler_name s3_bucket_source.glue_crawler_last_updated = datetime.datetime.utcnow() - s3_bucket_source.glue_state = ConnectionState.ACTIVE.value + s3_bucket_source.glue_state = ConnectionState.AUTHORIZED.value session.merge(s3_bucket_source) session.commit() @@ -514,7 +514,7 @@ def create_rds_connection(account: str, RdsInstanceSource.region == region, RdsInstanceSource.account_id == account).order_by( 
desc(RdsInstanceSource.detection_history_id)).first() - if rds_instance_source is None: + if not rds_instance_source: rds_instance_source = RdsInstanceSource(instance_id=instance, region=region, account_id=account) rds_instance_source.glue_database = glue_database @@ -522,7 +522,7 @@ def create_rds_connection(account: str, rds_instance_source.glue_connection = glue_connection rds_instance_source.glue_vpc_endpoint = glue_vpc_endpoint_id rds_instance_source.glue_crawler_last_updated = datetime.datetime.utcnow() - rds_instance_source.glue_state = ConnectionState.CRAWLING.value + rds_instance_source.glue_state = ConnectionState.AUTHORIZED.value session.merge(rds_instance_source) session.commit() @@ -645,12 +645,12 @@ def update_jdbc_connection_full(jdbc_instance: schemas.JDBCInstanceSourceUpdate) jdbc_instance_source.network_sg_id = jdbc_instance.network_sg_id jdbc_instance_source.jdbc_driver_class_name = jdbc_instance.jdbc_driver_class_name jdbc_instance_source.jdbc_driver_jar_uri = jdbc_instance.jdbc_driver_jar_uri - jdbc_instance_source.glue_database = None - jdbc_instance_source.glue_crawler = None + # jdbc_instance_source.glue_database = None + # jdbc_instance_source.glue_crawler = None jdbc_instance_source.glue_connection = jdbc_instance_source.glue_connection - jdbc_instance_source.glue_vpc_endpoint = None + # jdbc_instance_source.glue_vpc_endpoint = None jdbc_instance_source.glue_crawler_last_updated = datetime.datetime.utcnow() - jdbc_instance_source.glue_state = None + jdbc_instance_source.glue_state = ConnectionState.AUTHORIZED.value session.merge(jdbc_instance_source) session.commit() @@ -805,17 +805,25 @@ def add_third_account(account, role_arn): return True -def get_source_s3_account_region(): - return (get_session() - .query(S3BucketSource.region, S3BucketSource.account_id) - .distinct() - .all() - ) +# def get_source_s3_account_region(): +# return (get_session() +# .query(S3BucketSource.region, S3BucketSource.account_id) +# .distinct() +# .all() +# ) +# def get_source_proxy_account_region(): +# return (get_session() +# .query(Account.region, Account.account_id) +# .filter(Account.account_provider_id == Provider.JDBC_PROXY.value) +# .distinct() +# .all() +# ) -def get_source_rds_account_region(): +def get_source_aws_account_region(): return (get_session() - .query(RdsInstanceSource.region, RdsInstanceSource.account_id) + .query(Account.region, Account.account_id) + .filter(or_(Account.account_provider_id == Provider.AWS_CLOUD.value, Account.account_provider_id == Provider.JDBC_PROXY.value )) .distinct() .all() ) @@ -859,19 +867,14 @@ def copy_properties(jdbc_instance_target: JDBCInstanceSource, jdbc_instance_orig jdbc_instance_target.jdbc_driver_class_name = jdbc_instance_origin.jdbc_driver_class_name jdbc_instance_target.jdbc_driver_jar_uri = jdbc_instance_origin.jdbc_driver_jar_uri jdbc_instance_target.detection_history_id = 0 - # jdbc_instance_target.instance_class = jdbc_instance_origin.instance_class - # jdbc_instance_target.instance_status = jdbc_instance_origin.instance_status jdbc_instance_target.account_provider_id = jdbc_instance_origin.account_provider_id jdbc_instance_target.account_id = jdbc_instance_origin.account_id jdbc_instance_target.region = jdbc_instance_origin.region - # jdbc_instance_target.data_source_id = jdbc_instance_origin.data_source_id - # jdbc_instance_target.detection_history_id = jdbc_instance_origin.detection_history_id - # jdbc_instance_target.glue_database = jdbc_instance_origin.glue_database - # jdbc_instance_target.glue_crawler = 
jdbc_instance_origin.glue_crawler + jdbc_instance_target.glue_database = jdbc_instance_origin.glue_database + jdbc_instance_target.glue_crawler = jdbc_instance_origin.glue_crawler jdbc_instance_target.glue_connection = jdbc_instance_origin.glue_connection - # jdbc_instance_target.glue_vpc_endpoint = jdbc_instance_origin.glue_vpc_endpoint - # jdbc_instance_target.glue_state = jdbc_instance_origin.glue_state jdbc_instance_target.create_type = jdbc_instance_origin.create_type + jdbc_instance_target.glue_state = ConnectionState.AUTHORIZED.value return jdbc_instance_target def add_jdbc_conn(jdbcConn: schemas.JDBCInstanceSourceFullInfo): @@ -931,6 +934,11 @@ def get_account_list_by_provider(provider_id): return get_session().query(Account).filter(Account.account_provider_id == provider_id, Account.status == SourceAccountStatus.ENABLE.value).all() +def get_enable_account_list(): + return get_session().query(Account).filter(Account.status == SourceAccountStatus.ENABLE.value).all() + +def get_enable_region_list(): + return get_session().query(SourceRegion).filter(SourceRegion.status == SourceRegionStatus.ENABLE.value).all() def list_distinct_region_by_provider(provider_id) -> list[SourceRegion]: return get_session().query(SourceRegion).filter(SourceRegion.provider_id == provider_id, @@ -959,3 +967,126 @@ def get_total_glue_database_count(): def get_connected_glue_database_count(): list = list_glue_database_source_without_condition() return 0 if not list else list.filter(SourceGlueDatabase.glue_state == ConnectionState.ACTIVE.value).count() + + +def get_schema_by_snapshot(provider_id, account_id, region, instance): + return get_session().query(JDBCInstanceSource.jdbc_connection_schema, JDBCInstanceSource.network_subnet_id) \ + .filter(JDBCInstanceSource.account_provider_id == provider_id) \ + .filter(JDBCInstanceSource.account_id == account_id) \ + .filter(JDBCInstanceSource.instance_id == instance) \ + .filter(JDBCInstanceSource.region == region).first() + + +def get_connection_by_instance(provider_id, account_id, region, instance): + return get_session().query(JDBCInstanceSource.glue_connection) \ + .filter(JDBCInstanceSource.account_provider_id == provider_id) \ + .filter(JDBCInstanceSource.account_id == account_id) \ + .filter(JDBCInstanceSource.instance_id == instance) \ + .filter(JDBCInstanceSource.region == region).first() + + +def get_crawler_glue_db_by_instance(provider_id, account_id, region, instance): + return get_session().query(JDBCInstanceSource.glue_crawler, JDBCInstanceSource.glue_database, JDBCInstanceSource.glue_connection) \ + .filter(JDBCInstanceSource.account_provider_id == provider_id) \ + .filter(JDBCInstanceSource.account_id == account_id) \ + .filter(JDBCInstanceSource.instance_id == instance) \ + .filter(JDBCInstanceSource.region == region).first() + +def get_enable_account_list(): + return get_session().query(Account.account_provider_id, Account.account_id, Account.region) \ + .filter(Account.status == SourceAccountStatus.ENABLE.value).all() + +def update_schema_by_account(provider_id, account_id, instance, region, schema): + session = get_session() + jdbc_instance_source = session.query(JDBCInstanceSource).filter(JDBCInstanceSource.account_provider_id == provider_id, + JDBCInstanceSource.region == region, + JDBCInstanceSource.account_id == account_id, + JDBCInstanceSource.instance_id == instance).first() + if jdbc_instance_source: + jdbc_instance_source.jdbc_connection_schema = schema + session.commit() + +def list_s3_resources(account_id, region, condition): + 
session_result = get_session().query(S3BucketSource) + if account_id: + session_result = session_result.filter(S3BucketSource.account_id == account_id) + if region: + session_result = session_result.filter(S3BucketSource.region == region) + if condition: + return query_with_condition(session_result, condition) + return session_result + +def list_rds_resources(account_id, region, condition): + session_result = get_session().query(RdsInstanceSource) + if account_id: + session_result = session_result.filter(RdsInstanceSource.account_id == account_id) + if region: + session_result = session_result.filter(RdsInstanceSource.region == region) + if condition: + return query_with_condition(session_result, condition) + return session_result + +def list_glue_resources(account_id, region, condition): + session_result = get_session().query(SourceGlueDatabase) + if account_id: + session_result = session_result.filter(SourceGlueDatabase.account_id == account_id) + if region: + session_result = session_result.filter(SourceGlueDatabase.region == region) + if condition: + return query_with_condition(session_result, condition) + return session_result + +def list_jdbc_resources_by_provider(provider_id: int, account_id, region, condition): + session_result = get_session().query(JDBCInstanceSource).filter(JDBCInstanceSource.account_provider_id == provider_id) + if account_id: + session_result = session_result.filter(JDBCInstanceSource.account_id == account_id) + if region: + session_result = session_result.filter(JDBCInstanceSource.region == region) + if condition: + return query_with_condition(session_result, condition) + return session_result + +# ["account_id", "region", "bucket_name", "crawler_status", "last_updated_at", "last_updated_by"] +def get_datasource_from_s3(): + return get_session().query(S3BucketSource.account_id, + S3BucketSource.region, + S3BucketSource.bucket_name, + S3BucketSource.glue_state, + S3BucketSource.modify_time, + S3BucketSource.modify_by + ).all() + +# ["account_id", "region", "instance_name", "engine_type", "location", "crawler_status", "last_updated_at", "last_updated_by"] +def get_datasource_from_rds(): + return get_session().query(RdsInstanceSource.account_id, + RdsInstanceSource.region, + RdsInstanceSource.instance_id, + RdsInstanceSource.engine, + RdsInstanceSource.address, + RdsInstanceSource.glue_state, + RdsInstanceSource.modify_time, + RdsInstanceSource.modify_by + ).all() + +# ["account_id", "region", "database_name", "description", "location", "crawler_status", "last_updated_at", "last_updated_by"] +def get_datasource_from_glue(): + return get_session().query(SourceGlueDatabase.account_id, + SourceGlueDatabase.region, + SourceGlueDatabase.glue_database_name, + SourceGlueDatabase.glue_database_description, + SourceGlueDatabase.glue_database_location_uri, + SourceGlueDatabase.glue_state, + SourceGlueDatabase.modify_time, + SourceGlueDatabase.modify_by).all() + +# ["type", "account_id", "region", "instance_name", "description", "location", "crawler_status", "last_updated_at", "last_updated_by"] +def get_datasource_from_jdbc(): + return get_session().query(JDBCInstanceSource.account_provider_id, + JDBCInstanceSource.account_id, + JDBCInstanceSource.region, + JDBCInstanceSource.instance_id, + JDBCInstanceSource.description, + JDBCInstanceSource.jdbc_connection_url, + JDBCInstanceSource.glue_state, + JDBCInstanceSource.modify_time, + JDBCInstanceSource.modify_by).all() diff --git a/source/constructs/api/data_source/glue_database_detector.py 
b/source/constructs/api/data_source/glue_database_detector.py index 05f3d13e..04dab1c3 100644 --- a/source/constructs/api/data_source/glue_database_detector.py +++ b/source/constructs/api/data_source/glue_database_detector.py @@ -1,21 +1,15 @@ -import os - import boto3 -import logging - from common.constant import const from common.enum import ConnectionState, DatabaseType, Provider from db.database import get_session -from db.models_data_source import DetectionHistory, RdsInstanceSource, Account +from db.models_data_source import DetectionHistory from . import crud, schemas -from catalog.service import delete_catalog_by_database_region from sqlalchemy.orm import Session import asyncio +from common.reference_parameter import logger, admin_region sts_client = boto3.client('sts') -admin_account_region = boto3.session.Session().region_name -logger = logging.getLogger() -logger.setLevel(logging.INFO) + async def detect_glue_database_connection(session: Session, aws_account_id: str): iam_role_name = crud.get_iam_role(aws_account_id) @@ -43,7 +37,7 @@ async def detect_glue_database_connection(session: Session, aws_account_id: str) ) # glue_database_list = client.get_databases()['DatabaseList'] glue_database_list = get_all_databases(client) - db_glue_list = crud.list_glue_database_ar(account_id=aws_account_id, region=admin_account_region) + db_glue_list = crud.list_glue_database_ar(account_id=aws_account_id, region=admin_region) for item in db_glue_list: if not item.glue_database_name.upper().startswith(DatabaseType.JDBC.value): db_database_name_list.append(item.glue_database_name) @@ -67,7 +61,7 @@ async def detect_glue_database_connection(session: Session, aws_account_id: str) if item not in glue_database_name_list: refresh_list.append(item) crud.delete_not_exist_glue_database(refresh_list) - crud.update_glue_database_count(account=aws_account_id, region=admin_account_region) + crud.update_glue_database_count(account=aws_account_id, region=admin_region) def get_all_databases(glue_client): diff --git a/source/constructs/api/data_source/jdbc_database.py b/source/constructs/api/data_source/jdbc_database.py index b397b04f..ef90573d 100644 --- a/source/constructs/api/data_source/jdbc_database.py +++ b/source/constructs/api/data_source/jdbc_database.py @@ -15,11 +15,12 @@ def list_databases(self) -> list[str]: class MySQLDatabase(JdbcDatabase): ignored_databases = ['information_schema', 'innodb', 'mysql', 'performance_schema', 'sys', 'tmp'] - def __init__(self, host, port, user, password): + def __init__(self, host, port, user, password, ssl_verify_cert=False): self.host = host self.port = port self.user = user self.password = password + self.ssl_verify_cert = ssl_verify_cert def list_databases(self): databases = [] @@ -27,11 +28,13 @@ def list_databases(self): db = pymysql.connect(host=self.host, port=self.port, user=self.user, - password=self.password) + password=self.password, + ssl_verify_cert=self.ssl_verify_cert, + ) except Exception as e: logger.info(e) raise BizException(MessageEnum.SOURCE_JDBC_LIST_DATABASES_FAILED.get_code(), - str(e.args[1]) if e.args else traceback.format_exc()) + str(e.args[len(e.args) - 1]) if e.args else traceback.format_exc()) try: cursor = db.cursor() diff --git a/source/constructs/api/data_source/jdbc_detector.py b/source/constructs/api/data_source/jdbc_detector.py index 943d8a40..edff69cf 100644 --- a/source/constructs/api/data_source/jdbc_detector.py +++ b/source/constructs/api/data_source/jdbc_detector.py @@ -1,29 +1,18 @@ -import os - import boto3 -import 
logging - -from common.constant import const from common.enum import Provider, MessageEnum from db.database import get_session from db.models_data_source import DetectionHistory from common.abilities import convert_provider_id_2_name from . import crud -from catalog.service import delete_catalog_by_database_region from sqlalchemy.orm import Session import asyncio from db.models_data_source import JDBCInstanceSource from botocore.exceptions import ClientError from common.exception_handler import BizException +from common.reference_parameter import logger, admin_account_id, admin_region sts_client = boto3.client('sts') -admin_account_region = boto3.session.Session().region_name -caller_identity = sts_client.get_caller_identity() -partition = caller_identity['Arn'].split(':')[1] -admin_account_id = caller_identity.get('Account') -logger = logging.getLogger() -logger.setLevel(logging.INFO) async def detect_jdbc_connection(provider_id: int, account_id: str, session: Session): not_exist_connections = [] @@ -34,7 +23,7 @@ async def detect_jdbc_connection(provider_id: int, account_id: str, session: Ses else: history = DetectionHistory(provider=convert_provider_id_2_name(provider_id), account_id=account_id, source_type='jdbc', state=0) iam_role_name = crud.get_iam_role(admin_account_id) - regions = [admin_account_region] + regions = [admin_region] session.add(history) session.commit() assumed_role_object = sts_client.assume_role( @@ -65,7 +54,7 @@ async def detect_jdbc_connection(provider_id: int, account_id: str, session: Ses not_exist_connections.append(item.id) # delete not existed jdbc crud.delete_jdbc_connection_by_accounts(not_exist_connections) - region = admin_account_region if provider_id == Provider.AWS_CLOUD.value else None + region = admin_region if provider_id == Provider.AWS_CLOUD.value else None crud.update_jdbc_instance_count(provider=provider_id, account=account_id, region=region) async def detect_multiple_account_in_async(provider_id, accounts): diff --git a/source/constructs/api/data_source/jdbc_schema.py b/source/constructs/api/data_source/jdbc_schema.py new file mode 100644 index 00000000..caa92f2d --- /dev/null +++ b/source/constructs/api/data_source/jdbc_schema.py @@ -0,0 +1,127 @@ +import boto3 +import json +import traceback +from common.exception_handler import BizException +from common.enum import MessageEnum, Provider +from common.constant import const +from common.reference_parameter import logger, admin_account_id, admin_region, partition +from . 
import jdbc_database, crud +from .schemas import JdbcSource, JDBCInstanceSourceBase + +sts = boto3.client('sts') + + +def list_jdbc_databases(source: JdbcSource) -> list[str]: + url_arr = source.connection_url.split(":") + if len(url_arr) != 4: + raise BizException(MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_code(), MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_msg()) + if url_arr[1] != "mysql": + raise BizException(MessageEnum.SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED.get_code(), MessageEnum.SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED.get_msg()) + host = url_arr[2][2:] + port = int(url_arr[3].split("/")[0]) + user = source.username + password = source.password + ssl_verify_cert = True if source.ssl_verify_cert else False + if source.secret_id: + secrets_client = boto3.client('secretsmanager') + secret_response = secrets_client.get_secret_value(SecretId=source.secret_id) + secrets = json.loads(secret_response['SecretString']) + user = secrets['username'] + password = secrets['password'] + mysql_database = jdbc_database.MySQLDatabase(host, port, user, password, ssl_verify_cert) + databases = mysql_database.list_databases() + logger.info(databases) + return databases + + +def get_schema_by_snapshot(provider_id: int, account_id: str, region: str, instance: str): + res = crud.get_schema_by_snapshot(provider_id, account_id, region, instance) + return res[0].split('\n') if res else None, res[1] if res else None + + +def get_schema_by_real_time(provider_id: int, account_id: str, region: str, instance: str, db_info: bool = False): + db, subnet_id = None, None + assume_account, assume_region = __get_admin_info(JDBCInstanceSourceBase(account_provider_id=provider_id, account_id=account_id, instance_id=instance, region=region)) + connection_rds = crud.get_connection_by_instance(provider_id, account_id, region, instance) + glue = __get_glue_client(assume_account, assume_region) + connection = glue.get_connection(Name=connection_rds[0]).get('Connection', {}) + subnet_id = connection.get('PhysicalConnectionRequirements', {}).get('SubnetId') + if db_info: + connection_properties = connection.get("ConnectionProperties", {}) + jdbc_source = JdbcSource(username=connection_properties.get("USERNAME"), + password=connection_properties.get("PASSWORD"), + secret_id=connection_properties.get("SECRET_ID"), + connection_url=connection_properties.get("JDBC_CONNECTION_URL") + ) + try: + db = list_jdbc_databases(jdbc_source) + except Exception as e: + logger.info(e) + return db, subnet_id + + +def sync_schema_by_job(provider_id: int, account_id: str, region: str, instance: str, schemas: list): + jdbc_targets = [] + # Query Info + info = crud.get_crawler_glue_db_by_instance(provider_id, account_id, region, instance) + logger.info(f"info:{info}") + if not info: + return + for db_name in schemas: + trimmed_db_name = db_name.strip() + if trimmed_db_name: + jdbc_targets.append({ + 'ConnectionName': info[2], + 'Path': f"{trimmed_db_name}/%" + }) + # Update Crawler + assume_account, assume_region = __get_admin_info(JDBCInstanceSourceBase(account_provider_id=provider_id, account_id=account_id, instance_id=instance, region=region)) + crawler_role_arn = __gen_role_arn(account_id=assume_account, + region=assume_region, + role_name='GlueDetectionJobRole') + try: + __get_glue_client(assume_account, assume_region).update_crawler( + Name=info[0], + Role=crawler_role_arn, + DatabaseName=info[1], + Targets={ + 'JdbcTargets': jdbc_targets, + }, + SchemaChangePolicy={ + 'UpdateBehavior': 'UPDATE_IN_DATABASE', + 'DeleteBehavior': 
'DELETE_FROM_DATABASE' + } + ) + except Exception as e: + logger.error(traceback.format_exc()) + raise BizException(MessageEnum.BIZ_UNKNOWN_ERR.get_code(), + MessageEnum.BIZ_UNKNOWN_ERR.get_msg()) + # Update RDS + crud.update_schema_by_account(provider_id, account_id, instance, region, "\n".join(schemas)) + + +def __get_admin_info(jdbc): + account_id = jdbc.account_id if jdbc.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id + region = jdbc.region if jdbc.account_provider_id == Provider.AWS_CLOUD.value else admin_region + return account_id, region + + +def __get_glue_client(account, region): + iam_role_name = crud.get_iam_role(account) + logger.info(f"iam_role_name:{iam_role_name}") + assumed_role = sts.assume_role( + RoleArn=f"{iam_role_name}", + RoleSessionName="glue-connection" + ) + credentials = assumed_role['Credentials'] + glue = boto3.client('glue', + aws_access_key_id=credentials['AccessKeyId'], + aws_secret_access_key=credentials['SecretAccessKey'], + aws_session_token=credentials['SessionToken'], + region_name=region + ) + return glue + + +def __gen_role_arn(account_id: str, region: str, role_name: str): + return f'arn:{partition}:iam::{account_id}:role/{const.SOLUTION_NAME}{role_name}-{region}' diff --git a/source/constructs/api/data_source/main.py b/source/constructs/api/data_source/main.py index 135ec4d1..15f58961 100644 --- a/source/constructs/api/data_source/main.py +++ b/source/constructs/api/data_source/main.py @@ -1,11 +1,14 @@ -from fastapi import APIRouter +from io import BytesIO +from typing import List +from fastapi import APIRouter, File, UploadFile from fastapi_pagination import Page, Params from fastapi_pagination.ext.sqlalchemy import paginate from common.query_condition import QueryCondition from common.request_wrapper import inject_session from common.response_wrapper import BaseResponse -from . import crud, schemas, service +# from .resource_list import list_resources_by_database_type +from . 
import crud, schemas, service, jdbc_schema, resource_list router = APIRouter(prefix="/data-source", tags=["data-source"]) @@ -81,15 +84,15 @@ def delete_s3_connection(s3: schemas.SourceDeteteS3Connection): s3.bucket ) -@router.post("/disconnect-delete-catalog-jdbc", response_model=BaseResponse) -@inject_session -def disconnect_and_delete_catalog_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection): - return service.delete_jdbc_connection( - int(jdbc.account_provider), - jdbc.account_id, - jdbc.region, - jdbc.instance - ) +# @router.post("/disconnect-delete-catalog-jdbc", response_model=BaseResponse) +# @inject_session +# def disconnect_and_delete_catalog_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection): +# return service.delete_jdbc_connection( +# int(jdbc.account_provider), +# jdbc.account_id, +# jdbc.region, +# jdbc.instance +# ) @router.post("/hide-s3", response_model=BaseResponse) @inject_session @@ -190,8 +193,6 @@ def hide_glue_database(glueDatabase: schemas.SourceDeteteGlueDatabase): glueDatabase.name ) - - @router.post("/sync-glue-database", response_model=BaseResponse) @inject_session def sync_glue_database(glueDatabase: schemas.SourceGlueDatabaseBase): @@ -203,34 +204,24 @@ def sync_glue_database(glueDatabase: schemas.SourceGlueDatabaseBase): @router.post("/delete-jdbc", response_model=BaseResponse) @inject_session -def delete_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection): - return service.delete_jdbc_connection( - int(jdbc.account_provider), - jdbc.account_id, - jdbc.region, - jdbc.instance - ) - -@router.post("/delete-catalog-jdbc", response_model=BaseResponse) -@inject_session -def delete_catalog_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection): - return service.delete_jdbc_connection( +def delete_jdbc_connections(jdbc: schemas.SourceDeteteJDBCConnection): + return service.delete_jdbc_connections( int(jdbc.account_provider), jdbc.account_id, jdbc.region, - jdbc.instance, - delete_catalog_only=True + jdbc.instances ) -@router.post("/hide-jdbc", response_model=BaseResponse) -@inject_session -def hide_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection): - return service.hide_jdbc_connection( - int(jdbc.account_provider), - jdbc.account_id, - jdbc.region, - jdbc.instance - ) +# @router.post("/delete-catalog-jdbc", response_model=BaseResponse) +# @inject_session +# def delete_catalog_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection): +# return service.delete_jdbc_connection( +# int(jdbc.account_provider), +# jdbc.account_id, +# jdbc.region, +# jdbc.instance, +# delete_catalog_only=True +# ) @router.post("/sync-jdbc", response_model=BaseResponse) @inject_session @@ -352,7 +343,6 @@ def test_jdbc_conn(jdbc_conn_param: schemas.JDBCInstanceSourceBase): def get_data_location_list(): return service.list_data_location() - @router.get("/query-regions-by-provider", response_model=BaseResponse) @inject_session def query_regions_by_provider(provider_id: str): @@ -379,8 +369,54 @@ def list_buckets(account: schemas.AccountInfo): def query_connection_detail(account: schemas.JDBCInstanceSourceBase): return service.query_connection_detail(account) - @router.post("/jdbc-databases", response_model=BaseResponse[list[str]]) @inject_session def list_jdbc_databases(source: schemas.JdbcSource): - return service.list_jdbc_databases(source) + return jdbc_schema.list_jdbc_databases(source) + +@router.post("/batch-create", response_model=BaseResponse) +@inject_session +def batch_create(files: List[UploadFile] = File(...)): + return 
service.batch_create(files[0]) + +@router.post("/query-batch-status", response_model=BaseResponse) +@inject_session +def query_batch_status(batch: str): + return service.query_batch_status(batch) + +@router.post("/download-batch-file", response_model=BaseResponse) +@inject_session +def download_batch_file(filename: str): + return service.download_batch_file(filename) + +@router.post("/list-resources-by-database-type", response_model=BaseResponse) +@inject_session +def list_resources_by_database_type(database_type: str, condition: QueryCondition, account_id: str = None, region: str = None): + return paginate(resource_list.list_resources_by_database_type(database_type=database_type, + account_id=account_id, + region=region, + condition=condition), Params( + size=condition.size, + page=condition.page, + )) + +@router.post("/export-datasource", response_model=BaseResponse) +@inject_session +def export_datasource(key: str): + return service.export_datasource(key) + +@router.post("/delete-report", response_model=BaseResponse) +@inject_session +def delete_report(key: str): + return service.delete_report(key) + + +@router.post("/batch-delete", response_model=BaseResponse) +@inject_session +def batch_delete_resource(account, datasource_list): + return service.batch_delete_resource(account, datasource_list) + +# @router.post("/batch-sync-jdbc", response_model=BaseResponse) +# @inject_session +# def batch_sync_jdbc(connection_list: [schemas.JDBCInstanceSourceBase]): +# return service.batch_sync_jdbc(connection_list) diff --git a/source/constructs/api/data_source/rds_detector.py b/source/constructs/api/data_source/rds_detector.py index 847f3fcd..54e2fe12 100644 --- a/source/constructs/api/data_source/rds_detector.py +++ b/source/constructs/api/data_source/rds_detector.py @@ -1,8 +1,4 @@ -import os - import boto3 -import logging - from common.constant import const from common.enum import ConnectionState, DatabaseType, Provider from db.database import get_session @@ -12,11 +8,10 @@ from catalog.service import delete_catalog_by_database_region from sqlalchemy.orm import Session import asyncio +from common.reference_parameter import logger, admin_region sts_client = boto3.client('sts') -admin_account_region = boto3.session.Session().region_name -logger = logging.getLogger() -logger.setLevel(logging.INFO) + async def detect_rds_data_source(session: Session, aws_account_id: str): iam_role_name = crud.get_iam_role(aws_account_id) @@ -110,7 +105,7 @@ async def detect_rds_data_source(session: Session, aws_account_id: str): Account.account_id == aws_account_id, Account.region == region).first() # TODO support multiple regions - crud.update_rds_instance_count(account=aws_account_id, region=admin_account_region) + crud.update_rds_instance_count(account=aws_account_id, region=admin_region) async def detect_multiple_account_in_async(accounts): diff --git a/source/constructs/api/data_source/resource_list.py b/source/constructs/api/data_source/resource_list.py new file mode 100644 index 00000000..12c45035 --- /dev/null +++ b/source/constructs/api/data_source/resource_list.py @@ -0,0 +1,15 @@ +from common.enum import DatabaseType +from common.abilities import convert_database_type_2_provider +from common.query_condition import QueryCondition +from . 
import crud + + +def list_resources_by_database_type(database_type: str, account_id: str = None, region: str = None, condition: QueryCondition = None): + if database_type == DatabaseType.S3.value: + return crud.list_s3_resources(account_id, region, condition) + elif database_type == DatabaseType.RDS.value: + return crud.list_rds_resources(account_id, region, condition) + elif database_type == DatabaseType.GLUE.value: + return crud.list_glue_resources(account_id, region, condition) + else: + return crud.list_jdbc_resources_by_provider(convert_database_type_2_provider(database_type), account_id, region, condition) diff --git a/source/constructs/api/data_source/s3_detector.py b/source/constructs/api/data_source/s3_detector.py index 47fb4e20..d50631e0 100644 --- a/source/constructs/api/data_source/s3_detector.py +++ b/source/constructs/api/data_source/s3_detector.py @@ -1,10 +1,5 @@ -import os - import boto3 -import logging from sqlalchemy import desc - -from common.constant import const from common.enum import ConnectionState, DatabaseType from db.database import get_session from db.models_data_source import S3BucketSource, DetectionHistory @@ -13,12 +8,10 @@ from catalog.service import delete_catalog_by_database_region import asyncio from sqlalchemy.orm import Session +from common.reference_parameter import logger, admin_region -admin_account_region = boto3.session.Session().region_name sts_client = boto3.client('sts') -logger = logging.getLogger() -logger.setLevel(logging.INFO) async def detect_s3_data_source(session: Session, aws_account_id: str): iam_role_name = crud.get_iam_role(aws_account_id) @@ -110,7 +103,7 @@ async def detect_s3_data_source(session: Session, aws_account_id: str): crud.delete_s3_bucket_source_by_name(aws_account_id, deleted_s3_region, deleted_s3_bucket_source) # TODO support multiple regions - crud.update_s3_bucket_count(account=aws_account_id, region=admin_account_region) + crud.update_s3_bucket_count(account=aws_account_id, region=admin_region) async def detect_multiple_account_in_async(accounts): session = get_session() diff --git a/source/constructs/api/data_source/schemas.py b/source/constructs/api/data_source/schemas.py index bf32eef8..11026660 100644 --- a/source/constructs/api/data_source/schemas.py +++ b/source/constructs/api/data_source/schemas.py @@ -304,7 +304,7 @@ class SourceDeteteJDBCConnection(BaseModel): account_provider: int account_id: str region: str - instance: str + instances: list[str] class SourceDeteteS3Connection(BaseModel): account_id: str @@ -404,3 +404,4 @@ class JdbcSource(BaseModel): username: Optional[str] password: Optional[str] secret_id: Optional[str] + ssl_verify_cert: Optional[bool] diff --git a/source/constructs/api/data_source/service.py b/source/constructs/api/data_source/service.py index 79de4a87..00fa2102 100644 --- a/source/constructs/api/data_source/service.py +++ b/source/constructs/api/data_source/service.py @@ -1,35 +1,43 @@ +import asyncio +from datetime import datetime +from io import BytesIO import json import os +import random import re +import tempfile import time import traceback from time import sleep +from typing import List import boto3 +from fastapi import File, UploadFile +import openpyxl import pymysql from botocore.exceptions import ClientError -from catalog.service import delete_catalog_by_account_region as delete_catalog_by_account +from catalog.service import delete_catalog_by_account_region as delete_catalog_by_account, delete_catalog_by_database_region_batch from catalog.service import 
delete_catalog_by_database_region as delete_catalog_by_database_region -from common.abilities import (convert_provider_id_2_database_type, convert_provider_id_2_name, query_all_vpc) +from common.abilities import (convert_provider_id_2_database_type, convert_provider_id_2_name, insert_error_msg_2_cells, insert_success_2_cells, query_all_vpc) from common.constant import const from common.enum import (MessageEnum, ConnectionState, Provider, DataSourceType, DatabaseType, - JDBCCreateType) + JDBCCreateType, ProviderName) from common.exception_handler import BizException from common.query_condition import QueryCondition -from db.models_data_source import (Account, - JDBCInstanceSource, - SourceRegion) +from db.models_data_source import (Account) from discovery_job.service import can_delete_database as can_delete_job_database from discovery_job.service import delete_account as delete_job_by_account from discovery_job.service import delete_database as delete_job_database -from . import s3_detector, rds_detector, glue_database_detector, jdbc_detector, crud, jdbc_database +from common.concurrent_upload2s3 import concurrent_upload +from .jdbc_schema import list_jdbc_databases +from . import s3_detector, rds_detector, glue_database_detector, jdbc_detector, crud from .schemas import (AccountInfo, AdminAccountInfo, - JDBCInstanceSource, JDBCInstanceSourceUpdate, + JDBCInstanceSource, JDBCInstanceSourceUpdate, JdbcSource, ProviderResourceFullInfo, SourceNewAccount, SourceRegion, SourceResourceBase, SourceCoverage, @@ -37,8 +45,7 @@ JDBCInstanceSourceUpdateBase, DataLocationInfo, JDBCInstanceSourceBase, - JDBCInstanceSourceFullInfo, - JdbcSource) + JDBCInstanceSourceFullInfo) from common.reference_parameter import logger, admin_account_id, admin_region, partition, admin_bucket_name SLEEP_TIME = 5 @@ -57,6 +64,8 @@ sts = boto3.client('sts') """ :type : pyboto3.sts """ +__s3_client = boto3.client('s3') + _jdbc_url_patterns = [ r'jdbc:redshift://[\w.-]+:\d+/([\w-]+)', r'jdbc:redshift://[\w.-]+:\d+', @@ -67,9 +76,9 @@ r'jdbc:oracle:thin://@[\w.-]+:\d+/([\w-]+)', r'jdbc:oracle:thin://@[\w.-]+:\d+:\w+', r'jdbc:sqlserver://[\w.-]+:\d+;databaseName=([\w-]+)', - r'jdbc:sqlserver://[\w.-]+:\d+;database=([\w-]+)' - ] + r'jdbc:sqlserver://[\w.-]+:\d+;database=([\w-]+)'] +__s3_client = boto3.client('s3') def build_s3_targets(bucket, credentials, region, is_init): s3 = boto3.client('s3', @@ -116,7 +125,6 @@ def build_s3_targets(bucket, credentials, region, is_init): logger.info(s3_targets) return s3_targets - def sync_s3_connection(account: str, region: str, bucket: str): glue_connection_name = f"{const.SOLUTION_NAME}-{DatabaseType.S3.value}-{bucket}" glue_database_name = f"{const.SOLUTION_NAME}-{DatabaseType.S3.value}-{bucket}" @@ -220,11 +228,6 @@ def sync_s3_connection(account: str, region: str, bucket: str): except Exception as e: logger.info("update_crawler s3 error") logger.info(str(e)) - - # data source create crawler, job to run crawler - # response = glue.start_crawler( - # Name=crawler_name - # ) logger.info(response) except Exception as e: response = glue.create_crawler( @@ -240,11 +243,6 @@ def sync_s3_connection(account: str, region: str, bucket: str): }, ) logger.info(response) - # data source create crawler, job to run crawler - # response = glue.start_crawler( - # Name=crawler_name - # ) - # logger.info(response) crud.create_s3_connection(account, region, bucket, glue_connection_name, glue_database_name, crawler_name) except Exception as err: @@ -377,8 +375,7 @@ def sync_glue_database(account_id, 
region, glue_database_name): def sync_jdbc_connection(jdbc: JDBCInstanceSourceBase): - account_id = jdbc.account_id if jdbc.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = jdbc.region if jdbc.account_provider_id == Provider.AWS_CLOUD.value else admin_region + account_id, region = __get_admin_info(jdbc) ec2_client, credentials = __ec2(account=account_id, region=region) glue_client = __glue(account=account_id, region=region) lakeformation_client = __lakeformation(account=account_id, region=region) @@ -399,8 +396,30 @@ def sync_jdbc_connection(jdbc: JDBCInstanceSourceBase): logger.debug(f"conn_response type is:{type(conn_response)}") logger.debug(f"conn_response is:{conn_response}") + if conn_response.get('ConnectionProperties'): + username = conn_response.get('ConnectionProperties', {}).get('USERNAME') + password = conn_response.get('ConnectionProperties', {}).get('PASSWORD') + secret = conn_response.get('ConnectionProperties', {}).get("SECRET_ID"), + url = conn_response.get('ConnectionProperties', {}).get('JDBC_CONNECTION_URL'), + jdbc_instance = JDBCInstanceSource(instance_id=jdbc.instance_id, + account_provider_id=jdbc.account_provider_id, + account_id=jdbc.account_id, + region=jdbc.region, + jdbc_connection_url=url[0], + master_username=username, + password=password, + secret=secret[0]) + # jdbc_instance.jdbc_connection_url = url # condition_check(ec2_client, credentials, source.glue_state, conn_response['PhysicalConnectionRequirements']) - sync(glue_client, lakeformation_client, credentials, crawler_role_arn, jdbc, conn_response['ConnectionProperties']['JDBC_CONNECTION_URL'], source.jdbc_connection_schema) + sync(glue_client, + lakeformation_client, + credentials, + crawler_role_arn, + jdbc_instance, + source.jdbc_connection_schema) + else: + raise BizException(MessageEnum.BIZ_UNKNOWN_ERR.get_code(), + MessageEnum.BIZ_UNKNOWN_ERR.get_msg()) def condition_check(ec2_client, credentials, state, connection: dict): @@ -489,16 +508,15 @@ def condition_check(ec2_client, credentials, state, connection: dict): MessageEnum.SOURCE_AVAILABILITY_ZONE_NOT_EXISTS.get_msg()) -def sync(glue, lakeformation, credentials, crawler_role_arn, jdbc: JDBCInstanceSourceBase, url: str, schemas: str): +def sync(glue, lakeformation, credentials, crawler_role_arn, jdbc: JDBCInstanceSource, schemas: str): jdbc_targets = [] - database_type = convert_provider_id_2_database_type(jdbc.account_provider_id) - glue_database_name = f"{const.SOLUTION_NAME}-{database_type}-{jdbc.instance_id}" - crawler_name = f"{const.SOLUTION_NAME}-{database_type}-{jdbc.instance_id}" + _, glue_database_name, crawler_name = __gen_resources_name(jdbc) state, glue_connection_name = crud.get_jdbc_connection_glue_info(jdbc.account_provider_id, jdbc.account_id, jdbc.region, jdbc.instance_id) if state == ConnectionState.CRAWLING.value: raise BizException(MessageEnum.SOURCE_CONNECTION_CRAWLING.get_code(), MessageEnum.SOURCE_CONNECTION_CRAWLING.get_msg()) - db_names = get_db_names(url, schemas) + jdbc_source = JdbcSource(connection_url=jdbc.jdbc_connection_url, username=jdbc.master_username, password=jdbc.password, secret_id=jdbc.secret) + db_names = get_db_names_4_jdbc(jdbc_source, schemas) try: for db_name in db_names: trimmed_db_name = db_name.strip() @@ -687,6 +705,7 @@ def before_delete_jdbc_connection(provider_id, account, region, instance_id, dat MessageEnum.DISCOVERY_JOB_CAN_NOT_DELETE_DATABASE.get_msg()) else: logger.info(f"delete jdbc connection: 
{account},{region},{database_type},{jdbc_instance.instance_id}") + return jdbc_instance.glue_crawler def gen_assume_account(provider_id, account, region): account = account if provider_id == Provider.AWS_CLOUD.value else admin_account_id @@ -730,9 +749,14 @@ def delete_glue_database(provider_id: int, account: str, region: str, name: str) return True -def delete_jdbc_connection(provider_id: int, account: str, region: str, instance_id: str, delete_catalog_only=False): +async def __delete_jdbc_connection(provider_id: int, account: str, region: str, instance_id: str, delete_catalog_only=False): database_type = convert_provider_id_2_database_type(provider_id) - before_delete_jdbc_connection(provider_id, account, region, instance_id, database_type) + try: + before_delete_jdbc_connection(provider_id, account, region, instance_id, database_type) + except BizException as be: + return instance_id, be.__msg__() + except Exception as e: + return instance_id, str(e) assume_account, assume_region = gen_assume_account(provider_id, account, region) err = [] # 1/3 delete job database @@ -741,17 +765,19 @@ def delete_jdbc_connection(provider_id: int, account: str, region: str, instance logger.info('delete_job_database start') delete_job_database(account_id=account, region=region, database_type=database_type, database_name=instance_id) logger.info('delete_job_database end') + except BizException as be: + return instance_id, be.__msg__() except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, str(e) # 2/3 delete catalog try: logger.info('delete_catalog_by_database_region start') - delete_catalog_by_database_region(database=instance_id, region=region, type=database_type) + delete_catalog_by_database_region_batch(database=instance_id, region=region, type=database_type) logger.info('delete_catalog_by_database_region end') + except BizException as be: + return instance_id, be.__msg__() except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, str(e) if not delete_catalog_only: # 3/3 delete source @@ -762,53 +788,50 @@ def delete_jdbc_connection(provider_id: int, account: str, region: str, instance logger.info(f'delete_crawler start:{assume_account, jdbc_conn.glue_crawler}') glue.delete_crawler(Name=jdbc_conn.glue_crawler) logger.info(f'delete_crawler end:{jdbc_conn.glue_crawler}') + except BizException as be: + return instance_id, be.__msg__() except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, str(e) if jdbc_conn.glue_database: try: glue.delete_database(Name=jdbc_conn.glue_database) + except BizException as be: + return instance_id, be.__msg__() except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, str(e) if jdbc_conn.glue_connection: try: glue.delete_connection( CatalogId=assume_account, ConnectionName=jdbc_conn.glue_connection ) + except BizException as be: + return instance_id, be.__msg__() except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, str(e) crud.delete_jdbc_connection(provider_id, account, region, instance_id) try: crud.update_jdbc_instance_count(provider_id, account, region) + except BizException as be: + return instance_id, be.__msg__() except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, str(e) if not err: logger.error(err) # raise 
BizException(MessageEnum.SOURCE_S3_CONNECTION_DELETE_ERROR.get_code(), err) - return True - -def hide_jdbc_connection(provider_id: int, account: str, region: str, instance_id: str): - database_type = convert_provider_id_2_database_type(provider_id) - before_delete_jdbc_connection(provider_id, account, region, instance_id, database_type) - - err = [] - crud.hide_jdbc_connection(provider_id, account, region, instance_id) - try: - crud.update_jdbc_instance_count(provider_id, account, region) - except Exception as e: - logger.error(traceback.format_exc()) - err.append(str(e)) + return instance_id, const.EMPTY_STR - if not err: - logger.error(err) +async def __delete_jdbc_connections(provider_id: int, account: str, region: str, instances: list[str]): + tasks = [] + for instance in instances: + task = asyncio.create_task(__delete_jdbc_connection(provider_id, account, region, instance)) + tasks.append(task) + return await asyncio.gather(*tasks) - return True +def delete_jdbc_connections(provider_id: int, account: str, region: str, instances: list[str]): + return asyncio.run(__delete_jdbc_connections(provider_id, account, region, instances)) def gen_credentials(account: str): try: @@ -1091,10 +1114,10 @@ def sync_rds_connection(account: str, region: str, instance_name: str, rds_user= except Exception as e: logger.info("update_crawler error") logger.info(str(e)) - st_cr_response = glue.start_crawler( - Name=crawler_name - ) - logger.info(st_cr_response) + # st_cr_response = glue.start_crawler( + # Name=crawler_name + # ) + # logger.info(st_cr_response) except Exception as e: logger.info("sync_rds_connection get_crawler and create:") logger.info(str(e)) @@ -1111,10 +1134,10 @@ def sync_rds_connection(account: str, region: str, instance_name: str, rds_user= }, ) logger.info(response) - start_response = glue.start_crawler( - Name=crawler_name - ) - logger.info(start_response) + # start_response = glue.start_crawler( + # Name=crawler_name + # ) + # logger.info(start_response) crud.create_rds_connection(account, region, instance_name, glue_connection_name, glue_database_name, None, crawler_name) else: @@ -1155,12 +1178,6 @@ def before_delete_rds_connection(account: str, region: str, instance: str): if rds_instance is None: raise BizException(MessageEnum.SOURCE_RDS_NO_INSTANCE.get_code(), MessageEnum.SOURCE_RDS_NO_INSTANCE.get_msg()) - # if rds_instance.glue_crawler is None: - # raise BizException(MessageEnum.SOURCE_RDS_NO_CRAWLER.get_code(), - # MessageEnum.SOURCE_RDS_NO_CRAWLER.get_msg()) - # if rds_instance.glue_database is None: - # raise BizException(MessageEnum.SOURCE_RDS_NO_DATABASE.get_code(), - # MessageEnum.SOURCE_RDS_NO_DATABASE.get_msg()) # crawler, if crawling try to stop and raise, if pending raise directly state = crud.get_rds_instance_source_glue_state(account, region, instance) if state == ConnectionState.PENDING.value: @@ -1371,13 +1388,7 @@ def refresh_third_data_source(provider_id: int, accounts: list[str], type: str): raise BizException(MessageEnum.SOURCE_REFRESH_FAILED.get_code(), MessageEnum.SOURCE_REFRESH_FAILED.get_msg()) try: - # if type == DataSourceType.jdbc.value: jdbc_detector.detect(provider_id, accounts) - # elif type == DataSourceType.all.value: - # s3_detector.detect(accounts) - # rds_detector.detect(accounts) - # glue_database_detector.detect(accounts) - # jdbc_detector.detect(accounts) except Exception as e: logger.error(traceback.format_exc()) raise BizException(MessageEnum.SOURCE_CONNECTION_FAILED.get_code(), str(e)) @@ -1624,7 +1635,7 @@ def get_secrets(provider: 
int, account: str, region: str): region_name=region_aws ) """ :type : pyboto3.secretsmanager """ - response = secretsmanager.list_secrets() + response = secretsmanager.list_secrets(MaxResults=100) secrets = [] for secret in response['SecretList']: secrets.append( @@ -1650,12 +1661,18 @@ def import_glue_database(glueDataBase: SourceGlueDatabaseBase): crud.import_glue_database(glueDataBase, response) def update_jdbc_conn(jdbc_conn: JDBCInstanceSource): - get_db_names(jdbc_conn.jdbc_connection_url, jdbc_conn.jdbc_connection_schema) - account_id = jdbc_conn.account_id if jdbc_conn.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = jdbc_conn.region if jdbc_conn.account_provider_id == Provider.AWS_CLOUD.value else admin_region - res: JDBCInstanceSourceFullInfo = crud.get_jdbc_instance_source_glue(jdbc_conn.account_provider_id, jdbc_conn.account_id, jdbc_conn.region, jdbc_conn.instance_id) + jdbc_source = JdbcSource(connection_url=jdbc_conn.jdbc_connection_url, + username=jdbc_conn.master_username, + password=jdbc_conn.password, + secret_id=jdbc_conn.secret) + dbnames = get_db_names_4_jdbc(jdbc_source, jdbc_conn.jdbc_connection_schema) + account_id, region = __get_admin_info(jdbc_conn) + res: JDBCInstanceSourceFullInfo = crud.get_jdbc_instance_source_glue(jdbc_conn.account_provider_id, + jdbc_conn.account_id, + jdbc_conn.region, + jdbc_conn.instance_id) check_connection(res, jdbc_conn, account_id, region) - update_connection(res, jdbc_conn, account_id, region) + update_connection(res, jdbc_conn, account_id, region, dbnames) def check_connection(res: JDBCInstanceSourceFullInfo, jdbc_instance: JDBCInstanceSource, assume_account, assume_role): if not res: @@ -1681,10 +1698,10 @@ def check_connection(res: JDBCInstanceSourceFullInfo, jdbc_instance: JDBCInstanc else: pass -def update_connection(res: JDBCInstanceSourceFullInfo, jdbc_instance: JDBCInstanceSourceUpdate, assume_account, assume_role): - # logger.info(f"source.glue_connection is: {source.glue_connection}") +def update_connection(res: JDBCInstanceSourceFullInfo, jdbc_instance: JDBCInstanceSourceUpdate, assume_account, assume_region, db_names): + jdbc_targets = __gen_jdbc_targets_from_db_names(res.glue_connection, db_names) connectionProperties_dict = gen_conn_properties(jdbc_instance) - response = __glue(account=assume_account, region=assume_role).update_connection( + __glue(account=assume_account, region=assume_region).update_connection( CatalogId=assume_account, Name=res.glue_connection, ConnectionInput={ @@ -1701,6 +1718,18 @@ def update_connection(res: JDBCInstanceSourceFullInfo, jdbc_instance: JDBCInstan } } ) + crawler_role_arn = __gen_role_arn(account_id=assume_account, + region=assume_region, + role_name='GlueDetectionJobRole') + # Update Crawler + __update_crawler(res.account_provider_id, + res.account_id, + res.instance_id, + res.region, + jdbc_targets, + res.glue_crawler, + res.glue_database, + crawler_role_arn) crud.update_jdbc_connection_full(jdbc_instance) def __validate_jdbc_url(url: str): @@ -1708,26 +1737,34 @@ def __validate_jdbc_url(url: str): if re.match(pattern, url): return True - def add_jdbc_conn(jdbcConn: JDBCInstanceSource): - get_db_names(jdbcConn.jdbc_connection_url, jdbcConn.jdbc_connection_schema) - - account_id = jdbcConn.account_id if jdbcConn.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = jdbcConn.region if jdbcConn.account_provider_id == Provider.AWS_CLOUD.value else admin_region + jdbc_targets = [] + create_connection_response = {} + # 
get_db_names(jdbcConn.jdbc_connection_url, jdbcConn.jdbc_connection_schema) + account_id, region = __get_admin_info(jdbcConn) + crawler_role_arn = __gen_role_arn(account_id=account_id, + region=region, + role_name='GlueDetectionJobRole') list = crud.list_jdbc_instance_source_by_instance_id_account(jdbcConn, account_id) if list: raise BizException(MessageEnum.SOURCE_JDBC_ALREADY_EXISTS.get_code(), MessageEnum.SOURCE_JDBC_ALREADY_EXISTS.get_msg()) - database_type = convert_provider_id_2_database_type(jdbcConn.account_provider_id) - glue_connection_name = f"{const.SOLUTION_NAME}-{database_type}-{jdbcConn.instance_id}" - # network_availability_zone by subnetId + glue_connection_name, glue_database_name, crawler_name = __gen_resources_name(jdbcConn) ec2_client, __ = __ec2(account=account_id, region=region) - # return availability_zone + glue = __get_glue_client(account=account_id, region=region) try: + if not jdbcConn.jdbc_connection_schema: + source: JdbcSource = JdbcSource(connection_url=jdbcConn.jdbc_connection_url, + username=jdbcConn.master_username, + password=jdbcConn.password, + secret_id=jdbcConn.secret, + ssl_verify_cert=True if jdbcConn.jdbc_enforce_ssl == "true" else False + ) + jdbcConn.jdbc_connection_schema = ("\n").join(list_jdbc_databases(source)) availability_zone = ec2_client.describe_subnets(SubnetIds=[jdbcConn.network_subnet_id])['Subnets'][0]['AvailabilityZone'] try: connectionProperties_dict = gen_conn_properties(jdbcConn) - response = __glue(account=account_id, region=region).create_connection( + create_connection_response = __glue(account=account_id, region=region).create_connection( CatalogId=account_id, ConnectionInput={ 'Name': glue_connection_name, @@ -1748,16 +1785,44 @@ def add_jdbc_conn(jdbcConn: JDBCInstanceSource): }, ) except ClientError as ce: - logger.error(traceback.format_exc()) if ce.response['Error']['Code'] == 'InvalidInputException': raise BizException(MessageEnum.SOURCE_JDBC_INPUT_INVALID.get_code(), MessageEnum.SOURCE_JDBC_INPUT_INVALID.get_msg()) - + if ce.response['Error']['Code'] == 'AlreadyExistsException': + raise BizException(MessageEnum.SOURCE_JDBC_ALREADY_EXISTS.get_code(), + MessageEnum.SOURCE_JDBC_ALREADY_EXISTS.get_msg()) except Exception as e: logger.error(traceback.format_exc()) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: + if create_connection_response.get('ResponseMetadata', {}).get('HTTPStatusCode') != 200: raise BizException(MessageEnum.SOURCE_JDBC_CREATE_FAIL.get_code(), MessageEnum.SOURCE_JDBC_CREATE_FAIL.get_msg()) + # Creare Glue database + glue.create_database(DatabaseInput={'Name': glue_database_name}) + # Create Crawler + jdbc_source = JdbcSource(connection_url=jdbcConn.jdbc_connection_url, username=jdbcConn.master_username, password=jdbcConn.password, secret_id=jdbcConn.secret) + db_names = get_db_names_4_jdbc(jdbc_source, jdbcConn.jdbc_connection_schema) + for db_name in db_names: + trimmed_db_name = db_name.strip() + if trimmed_db_name: + jdbc_targets.append({ + 'ConnectionName': glue_connection_name, + 'Path': f"{trimmed_db_name}/%" + }) + try: + glue.create_crawler( + Name=crawler_name, + Role=crawler_role_arn, + DatabaseName=glue_database_name, + Targets={ + 'JdbcTargets': jdbc_targets, + }, + Tags={ + const.TAG_KEY: const.TAG_VALUE, + const.TAG_ADMIN_ACCOUNT_ID: admin_account_id + }, + ) + except Exception: + logger.error(traceback.format_exc()) jdbcConn.network_availability_zone = availability_zone jdbcConn.create_type = JDBCCreateType.ADD.value jdbc_conn_insert = JDBCInstanceSourceFullInfo() @@ -1772,7 
+1837,6 @@ def add_jdbc_conn(jdbcConn: JDBCInstanceSource): jdbc_conn_insert.jdbc_enforce_ssl = jdbcConn.jdbc_enforce_ssl jdbc_conn_insert.kafka_ssl_enabled = jdbcConn.kafka_ssl_enabled jdbc_conn_insert.master_username = jdbcConn.master_username - # jdbc_conn_insert.password = jdbcConn.password jdbc_conn_insert.skip_custom_jdbc_cert_validation = jdbcConn.skip_custom_jdbc_cert_validation jdbc_conn_insert.custom_jdbc_cert = jdbcConn.custom_jdbc_cert jdbc_conn_insert.custom_jdbc_cert_string = jdbcConn.custom_jdbc_cert_string @@ -1784,8 +1848,9 @@ def add_jdbc_conn(jdbcConn: JDBCInstanceSource): jdbc_conn_insert.jdbc_driver_class_name = jdbcConn.jdbc_driver_class_name jdbc_conn_insert.jdbc_driver_jar_uri = jdbcConn.jdbc_driver_jar_uri jdbc_conn_insert.create_type = jdbcConn.create_type - # jdbc_conn_insert.connection_status = 'UNCONNECTED' jdbc_conn_insert.glue_connection = glue_connection_name + jdbc_conn_insert.glue_crawler = crawler_name + jdbc_conn_insert.glue_database = glue_database_name crud.add_jdbc_conn(jdbc_conn_insert) except ClientError as ce: logger.error(traceback.format_exc()) @@ -1798,9 +1863,6 @@ def add_jdbc_conn(jdbcConn: JDBCInstanceSource): else: raise BizException(MessageEnum.BIZ_UNKNOWN_ERR.get_code(), MessageEnum.BIZ_UNKNOWN_ERR.get_msg()) - except Exception as e: - raise BizException(MessageEnum.BIZ_UNKNOWN_ERR.get_code(), - MessageEnum.BIZ_UNKNOWN_ERR.get_msg()) def gen_conn_properties(jdbcConn): connectionProperties_dict = {} @@ -1826,7 +1888,7 @@ def gen_conn_properties(jdbcConn): def test_jdbc_conn(jdbc_conn_param: JDBCInstanceSourceBase): res = "FAIL" - account_id, region = gen_assume_info(jdbc_conn_param) + account_id, region = __get_admin_info(jdbc_conn_param) cursor = None connection = None # get connection name from sdp db @@ -1873,8 +1935,7 @@ def import_jdbc_conn(jdbc_conn: JDBCInstanceSourceBase): if crud.list_jdbc_connection_by_connection(jdbc_conn.instance_id): raise BizException(MessageEnum.SOURCE_JDBC_ALREADY_IMPORTED.get_code(), MessageEnum.SOURCE_JDBC_ALREADY_IMPORTED.get_msg()) - account_id = jdbc_conn.account_id if jdbc_conn.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = jdbc_conn.region if jdbc_conn.account_provider_id == Provider.AWS_CLOUD.value else admin_region + account_id, region = __get_admin_info() try: res_connection = __glue(account_id, region).get_connection(Name=jdbc_conn.instance_id)['Connection'] except ClientError as ce: @@ -1968,7 +2029,6 @@ def __create_jdbc_url(engine: str, host: str, port: str): # Add S3 bucket, SQS queues access policies def __update_access_policy_for_account(): s3_resource = boto3.session.Session().resource('s3') - # for cn_region in const.CN_REGIONS: # check if s3 bucket, sqs exists bucket_name = admin_bucket_name try: @@ -2063,7 +2123,7 @@ def __update_access_policy_for_account(): ], "Resource": f"arn:{partition}:s3:::{bucket_name}/glue-database/*" } - s3 = boto3.client('s3') + s3 = __s3_client """ :type : pyboto3.s3 """ try: restored_statements = [] @@ -2284,7 +2344,6 @@ def __list_rds_schema(account, region, credentials, instance_name, payload, rds_ logger.info(schema_path) return schema_path - def __delete_data_source_by_account(account_id: str, region: str): try: crud.delete_s3_bucket_source_by_account(account_id=account_id, region=region) @@ -2295,25 +2354,15 @@ def __delete_data_source_by_account(account_id: str, region: str): except Exception: logger.error(traceback.format_exc()) - def __delete_account(account_id: str, region: str): try: 
crud.delete_account_by_region(account_id=account_id, region=region) except Exception: logger.error(traceback.format_exc()) - def query_glue_connections(account: AccountInfo): - res = [] - list = [] - account_id = account.account_id if account.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = account.region if account.account_provider_id == Provider.AWS_CLOUD.value else admin_region - - # list = __glue(account=account_id, region=region).get_connections(CatalogId=account_id, - # Filter={'ConnectionType': 'JDBC'}, - # MaxResults=100, - # HidePassword=True)['ConnectionList'] - + res, list = [], [] + account_id, region = __get_admin_info(account) next_token = "" while True: @@ -2330,7 +2379,7 @@ def query_glue_connections(account: AccountInfo): if not next_token: break jdbc_list = query_jdbc_connections_sub_info() - jdbc_dict = {item[0]:f"{convert_provider_id_2_name(item[1])}-{item[2]}" for item in jdbc_list} + jdbc_dict = {item[0]: f"{convert_provider_id_2_name(item[1])}-{item[2]}" for item in jdbc_list} for item in list: if not item['Name'].startswith(const.SOLUTION_NAME): if item['Name'] in jdbc_dict: @@ -2342,7 +2391,7 @@ def query_jdbc_connections_sub_info(): return crud.query_jdbc_connections_sub_info() def list_buckets(account: AdminAccountInfo): - _, region = gen_assume_info(account) + _, region = __get_admin_info(account) iam_role_name = crud.get_iam_role(account.account_id) assumed_role = sts.assume_role(RoleArn=f"{iam_role_name}", RoleSessionName="glue-s3-connection") @@ -2359,54 +2408,46 @@ def query_glue_databases(account: AdminAccountInfo): return __glue(account=account.account_id, region=account.region).get_databases()['DatabaseList'] def query_account_network(account: AccountInfo): - accont_id = account.account_id if account.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = account.region if account.region == Provider.AWS_CLOUD.value else admin_region - logger.info(f'accont_id is:{accont_id},region is {region}') - ec2_client, __ = __ec2(account=accont_id, region=region) + account_id, region = __get_admin_info(account) + ec2_client, __ = __ec2(account=account_id, region=region) vpcs = query_all_vpc(ec2_client) - # vpcs = [vpc['VpcId'] for vpc in query_all_vpc(ec2_client)] vpc_list = [{"vpcId": vpc.get('VpcId'), "name": gen_resource_name(vpc)} for vpc in vpcs] - # vpc_list = [{"vpcId": vpc['VpcId'], "name": gen_resource_name(vpc)} for vpc in vpcs] if account.account_provider_id != Provider.AWS_CLOUD.value: res = __query_third_account_network(vpc_list, ec2_client) - logger.info(f"query_third_account_network res is {res}") return res else: return __query_aws_account_network(vpc_list, ec2_client) +# async def add_conn_jdbc_async(jdbcConn: JDBCInstanceSource): +# key = f"{jdbcConn.account_provider_id}/{jdbcConn.account_id}/{jdbcConn.region}" +# try: +# add_jdbc_conn(jdbcConn) +# return (key, "SUCCESSED", "") +# except Exception as e: +# return (key, "FAILED", str(e)) + def __query_third_account_network(vpc_list, ec2_client: any): try: - response = ec2_client.describe_security_groups(Filters=[ {'Name': 'vpc-id', 'Values': [vpc["vpcId"] for vpc in vpc_list]}, {'Name': 'group-name', 'Values': [const.SECURITY_GROUP_JDBC]} ]) vpc_ids = [item['VpcId'] for item in response['SecurityGroups']] subnets = ec2_client.describe_subnets(Filters=[{'Name': 'vpc-id', 'Values': [vpc_ids[0]]}])['Subnets'] - # private_subnet = list(filter(lambda x: not x["MapPublicIpOnLaunch"], subnets)) selected_subnet = subnets subnets_str_from_env = 
os.getenv('SubnetIds', '') if subnets_str_from_env: subnets_from_env = subnets_str_from_env.split(',') selected_subnet = [item for item in subnets if item.get('SubnetId') in subnets_from_env] - # target_subnet = private_subnet[0] if private_subnet else subnets[0] target_subnets = [{'subnetId': subnet["SubnetId"], 'arn': subnet["SubnetArn"], "subnetName": gen_resource_name(subnet)} for subnet in selected_subnet] vpc_info = ec2_client.describe_vpcs(VpcIds=[vpc_ids[0]])['Vpcs'][0] - return {"vpcs": [{'vpcId': vpc_info['VpcId'], - 'vpcName': [obj for obj in vpc_info['Tags'] if obj["Key"] == "Name"][0]["Value"], + tag = [obj for obj in vpc_info.get('Tags', []) if obj.get("Key") == "Name"] + return {"vpcs": [{'vpcId': vpc_info.get('VpcId'), + 'vpcName': tag[0].get("Value", "-") if len(tag) > 0 else "-", 'subnets': target_subnets, - 'securityGroups': [{'securityGroupId': response['SecurityGroups'][0]['GroupId'], - 'securityGroupName': response['SecurityGroups'][0]['GroupName']}]}] - } - # return {"vpcs": [{'vpcId': vpc_info['VpcId'], - # 'vpcName': [obj for obj in vpc_info['Tags'] if obj["Key"] == "Name"][0]["Value"], - # 'subnets': [{'subnetId': target_subnet['SubnetId'], - # 'arn': target_subnet['SubnetArn'], - # "subnetName": gen_resource_name(target_subnet) - # }], - # 'securityGroups': [{'securityGroupId': response['SecurityGroups'][0]['GroupId'], - # 'securityGroupName': response['SecurityGroups'][0]['GroupName']}]}] - # } + 'securityGroups': [{'securityGroupId': response.get('SecurityGroups',[])[0].get('GroupId'), + 'securityGroupName': response.get('SecurityGroups',[])[0].get('GroupName')}]}] + } except ClientError as ce: logger.error(traceback.format_exc()) if ce.response['Error']['Code'] == 'InvalidGroup.NotFound': @@ -2439,47 +2480,78 @@ def gen_resource_name(resource): else: return '-' -def gen_assume_info(account): - accont_id = account.account_id if account.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id - region = account.region if account.region == Provider.AWS_CLOUD.value else admin_region - return accont_id, region - - def test_glue_conn(account, connection): return boto3.client('glue').start_connection_test( CatalogId=account, ConnectionName=connection )['ConnectionTest']['Status'] +# def list_data_location(): +# res = [] +# provider_list = crud.list_distinct_provider() +# for item in provider_list: +# regions: list[SourceRegion] = crud.list_distinct_region_by_provider(item.id) +# accounts_db = crud.get_account_list_by_provider(item.id) +# if not regions: +# continue +# if not accounts_db: +# continue +# for subItem in regions: +# accounts = [] +# for account in accounts_db: +# if account.region == subItem.region_name: +# accounts.append(account) +# if len(accounts) == 0: +# continue +# location = DataLocationInfo() +# location.account_count = len(accounts) +# location.source = item.provider_name +# location.region = subItem.region_name +# location.coordinate = subItem.region_cord +# location.region_alias = subItem.region_alias +# location.provider_id = item.id +# res.append(location) +# res = sorted(res, key=lambda x: x.account_count, reverse=True) +# return res def list_data_location(): res = [] - provider_list = crud.list_distinct_provider() - for item in provider_list: - regions: list[SourceRegion] = crud.list_distinct_region_by_provider(item.id) - accounts_db = crud.get_account_list_by_provider(item.id) - if not regions: - continue - if not accounts_db: - continue - for subItem in regions: - accounts = [] - for account in accounts_db: - if 
account.region == subItem.region_name: - accounts.append(account) - if len(accounts) == 0: - continue - location = DataLocationInfo() - location.account_count = len(accounts) - location.source = item.provider_name - location.region = subItem.region_name - location.coordinate = subItem.region_cord - location.region_alias = subItem.region_alias - location.provider_id = item.id - res.append(location) + provider_region_account_dict = {} + provider_region_detail_dict = {} + all_enable_accounts: List[Account] = crud.get_enable_account_list() + all_enable_regions: List[SourceRegion] = crud.get_enable_region_list() + if all_enable_regions: + for region in all_enable_regions: + provider_region_detail_dict[f"{region.provider_id}|{region.region_name}"] = {const.REGION_CORD: region.region_cord, + const.REGION_ALIAS: region.region_alias} + if all_enable_accounts: + for account in all_enable_accounts: + if account.account_provider_id == Provider.AWS_CLOUD.value or account.account_provider_id == Provider.JDBC_PROXY.value: + provider_region_account_dict = __gen_account_set(provider_region_account_dict, account, Provider.AWS_CLOUD.value) + else: + provider_region_account_dict = __gen_account_set(provider_region_account_dict, account, account.account_provider_id) + for product_region, account_set in provider_region_account_dict.items(): + provider_id = int(product_region.split("|")[0]) + location = DataLocationInfo() + location.account_count = len(account_set) + location.source = convert_provider_id_2_name(provider_id) + location.region = product_region.split("|")[1] + location.coordinate = provider_region_detail_dict.get(product_region, {}).get(const.REGION_CORD) + location.region_alias = provider_region_detail_dict.get(product_region, {}).get(const.REGION_ALIAS) + location.provider_id = provider_id + res.append(location) res = sorted(res, key=lambda x: x.account_count, reverse=True) return res +def __gen_account_set(details: dict, account: Account, key: int): + tmp_value = details.get(f"{key}|{account.region}") + if tmp_value: + tmp_value.add(account.account_id) + else: + tmp_value = set() + tmp_value.add(account.account_id) + details[f"{key}|{account.region}"] = tmp_value + return details def query_regions_by_provider(provider_id: int): return crud.query_regions_by_provider(provider_id) @@ -2508,6 +2580,25 @@ def query_full_provider_resource_infos(): def list_providers(): return crud.query_provider_list() +def get_db_names_4_jdbc(jdbc: JdbcSource, schemas: str): + if not __validate_jdbc_url(jdbc.connection_url): + raise BizException(MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_code(), + MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_msg()) + # list schemas + db_names = set() + if jdbc.connection_url.startswith('jdbc:mysql'): + schemas = list_jdbc_databases(jdbc) + return set(schemas) + else: + schema = get_schema_from_url(jdbc.connection_url) + if schema: + db_names.add(schema) + if schemas: + db_names.update(schemas.splitlines()) + if not db_names: + raise BizException(MessageEnum.SOURCE_JDBC_JDBC_NO_DATABASE.get_code(), + MessageEnum.SOURCE_JDBC_JDBC_NO_DATABASE.get_msg()) + return db_names def get_db_names(url: str, schemas: str): if not __validate_jdbc_url(url): @@ -2525,7 +2616,6 @@ def get_db_names(url: str, schemas: str): MessageEnum.SOURCE_JDBC_JDBC_NO_DATABASE.get_msg()) return db_names - def get_schema_from_url(url): for pattern in _jdbc_url_patterns: match = re.match(pattern, url) @@ -2568,7 +2658,7 @@ def grant_lake_formation_permission(credentials, crawler_role_arn, glue_database def 
query_connection_detail(account: JDBCInstanceSourceBase): - account_id, region = gen_assume_info(account) + account_id, region = __get_admin_info(account) source: JDBCInstanceSourceFullInfo = crud.get_jdbc_instance_source_glue(provider_id=account.account_provider_id, account=account.account_id, region=account.region, @@ -2581,29 +2671,356 @@ def query_connection_detail(account: JDBCInstanceSourceBase): conn['ConnectionProperties']['JDBC_CONNECTION_SCHEMA'] = source.jdbc_connection_schema return conn +def __gen_resources_name(jdbc): + database_type = convert_provider_id_2_database_type(jdbc.account_provider_id) + glue_connection_name = f"{const.SOLUTION_NAME}-{database_type}-{jdbc.instance_id}" + glue_database_name = glue_connection_name + crawler_name = f"{const.SOLUTION_NAME}-{database_type}-{jdbc.instance_id}" + return glue_connection_name, glue_database_name, crawler_name def __get_excludes_file_exts(): extensions = list(set([ext for extensions_list in const.UNSTRUCTURED_FILES.values() for ext in extensions_list])) return ["*.{" + ",".join(extensions) + "}"] +def __get_glue_client(account, region): + iam_role_name = crud.get_iam_role(account) + assumed_role = sts.assume_role( + RoleArn=f"{iam_role_name}", + RoleSessionName="glue-connection" + ) + credentials = assumed_role['Credentials'] + glue = boto3.client('glue', + aws_access_key_id=credentials['AccessKeyId'], + aws_secret_access_key=credentials['SecretAccessKey'], + aws_session_token=credentials['SessionToken'], + region_name=region + ) + return glue + +# def list_jdbc_databases(source: JdbcSource) -> list[str]: +# url_arr = source.connection_url.split(":") +# if len(url_arr) != 4: +# raise BizException(MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_code(), MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_msg()) +# if url_arr[1] != "mysql": +# raise BizException(MessageEnum.SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED.get_code(), MessageEnum.SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED.get_msg()) +# host = url_arr[2][2:] +# port = int(url_arr[3].split("/")[0]) +# user = source.username +# password = source.password +# if source.secret_id: +# secrets_client = boto3.client('secretsmanager') +# secret_response = secrets_client.get_secret_value(SecretId=source.secret_id) +# secrets = json.loads(secret_response['SecretString']) +# user = secrets['username'] +# password = secrets['password'] +# mysql_database = jdbc_database.MySQLDatabase(host, port, user, password) +# databases = mysql_database.list_databases() +# logger.info(databases) +# return databases + +def batch_create(file: UploadFile = File(...)): + res_column_index = 12 + time_str = time.time() + jdbc_from_excel_set = set() + created_jdbc_list = [] + account_set = set() + # Check if the file is an Excel file + if not file.filename.endswith('.xlsx'): + raise BizException(MessageEnum.SOURCE_BATCH_CREATE_FORMAT_ERR.get_code(), + MessageEnum.SOURCE_BATCH_CREATE_FORMAT_ERR.get_msg()) + # Read the Excel file + content = file.file.read() + workbook = openpyxl.load_workbook(BytesIO(content), read_only=False) + try: + sheet = workbook.get_sheet_by_name(const.BATCH_SHEET) + except KeyError: + raise BizException(MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_code(), + MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_msg()) + header = [cell for cell in sheet.iter_rows(min_row=2, max_row=2, values_only=True)][0] + sheet.delete_cols(12, amount=2) + sheet.insert_cols(12, amount=2) + sheet.cell(row=2, column=12, value="Result") + sheet.cell(row=2, column=13, value="Details") + accounts = 
crud.get_enable_account_list() + accounts_list = [f"{account[0]}/{account[1]}/{account[2]}" for account in accounts] + no_content = True + for row_index, row in enumerate(sheet.iter_rows(min_row=3), start=2): + if all(cell.value is None for cell in row): + continue + no_content = False + res, msg = __check_empty_for_field(row, header) + if res: + insert_error_msg_2_cells(sheet, row_index, msg, res_column_index) + elif sheet.cell(row=row_index + 1, column=2).value not in [0, 1]: + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[1]} must be 0 or 1", res_column_index) + elif not __validate_jdbc_url(str(row[3].value)): + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[3]} must be in the format jdbc:protocol://host:port", res_column_index) + elif not str(row[3].value).startswith('jdbc:mysql') and not row[4].value: + insert_error_msg_2_cells(sheet, row_index, f"Non-MySQL-type data source {header[4]} cannot be null", res_column_index) + elif len(str(row[2].value)) > const.CONNECTION_DESC_MAX_LEN: + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[2]} must not exceed 2048", res_column_index) + elif f"{row[10].value}/{row[8].value}/{row[9].value}/{row[0].value}" in jdbc_from_excel_set: + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[0]}, {header[8]}, {header[9]}, {header[10]} already exist in the preceding rows", res_column_index) + elif f"{row[10].value}/{row[8].value}/{row[9].value}" not in accounts_list: + # Account.account_provider_id, Account.account_id, Account.region + insert_error_msg_2_cells(sheet, row_index, "The account is not existed!", res_column_index) + else: + jdbc_from_excel_set.add(f"{row[10].value}/{row[8].value}/{row[9].value}/{row[0].value}") + account_set.add(f"{row[10].value}/{row[8].value}/{row[9].value}") + created_jdbc_list.append(__gen_created_jdbc(row)) + if no_content: + raise BizException(MessageEnum.SOURCE_BATCH_SHEET_NO_CONTENT.get_code(), + MessageEnum.SOURCE_BATCH_SHEET_NO_CONTENT.get_msg()) + # Query network info + if account_set: + account_info = list(account_set)[0].split("/") + network = query_account_network(AccountInfo(account_provider_id=account_info[0], account_id=account_info[1], region=account_info[2])) \ + .get('vpcs', [])[0] + vpc_id = network.get('vpcId') + subnets = [subnet.get('subnetId') for subnet in network.get('subnets')] + security_group_id = network.get('securityGroups', [])[0].get('securityGroupId') + created_jdbc_list = __map_network_jdbc(created_jdbc_list, subnets, security_group_id) + batch_result = asyncio.run(batch_add_conn_jdbc(created_jdbc_list)) + result = {f"{item[0]}/{item[1]}/{item[2]}/{item[3]}": f"{item[4]}/{item[5]}" for item in batch_result} + for row_index, row in enumerate(sheet.iter_rows(min_row=3), start=2): + if row[11].value: + continue + v = result.get(f"{row[10].value}/{row[8].value}/{row[9].value}/{row[0].value}") + if v: + if v.split('/')[0] == "SUCCESSED": + insert_success_2_cells(sheet, row_index, res_column_index) + else: + insert_error_msg_2_cells(sheet, row_index, v.split('/')[1], res_column_index) + # Write into excel + excel_bytes = BytesIO() + workbook.save(excel_bytes) + excel_bytes.seek(0) + # Upload to S3 + batch_create_ds = f"{const.BATCH_CREATE_REPORT_PATH}/report_{time_str}.xlsx" + __s3_client.upload_fileobj(excel_bytes, admin_bucket_name, batch_create_ds) + return f'report_{time_str}' + +def __check_empty_for_field(row, header): + if row[0].value is None or str(row[0].value).strip() == const.EMPTY_STR: + return True, f"{header[0]} 
should not be empty" + if row[1].value is None or str(row[1].value).strip() == const.EMPTY_STR: + return True, f"{header[1]} should not be empty" + if row[3].value is None or str(row[3].value).strip() == const.EMPTY_STR: + return True, f"{header[3]} should not be empty" + if row[5].value is None or str(row[5].value).strip() == const.EMPTY_STR: + if row[6].value is None or str(row[6].value).strip() == const.EMPTY_STR: + return True, f"{header[6]} should not be empty when {header[5]} is empty" + if row[7].value is None or str(row[7].value).strip() == const.EMPTY_STR: + return True, f"{header[7]} should not be empty when {header[5]} is empty" + if row[8].value is None or str(row[8].value).strip() == const.EMPTY_STR: + return True, f"{header[8]} should not be empty" + if row[9].value is None or str(row[9].value).strip() == const.EMPTY_STR: + return True, f"{header[9]} should not be empty" + if row[10].value is None or str(row[10].value).strip() == const.EMPTY_STR: + return True, f"{header[10]} should not be empty" + return False, None + +def __map_network_jdbc(created_jdbc_list, subnets, security_group_id): + res = [] + for index, item in enumerate(created_jdbc_list): + item.network_sg_id = security_group_id + item.network_subnet_id = subnets[0] if index % 2 == 0 else subnets[1] + res.append(item) + return res -def list_jdbc_databases(source: JdbcSource) -> list[str]: - url_arr = source.connection_url.split(":") - if len(url_arr) != 4: - raise BizException(MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_code(), MessageEnum.SOURCE_JDBC_URL_FORMAT_ERROR.get_msg()) - if url_arr[1] != "mysql": - raise BizException(MessageEnum.SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED.get_code(), MessageEnum.SOURCE_JDBC_LIST_DATABASES_NOT_SUPPORTED.get_msg()) - host = url_arr[2][2:] - port = int(url_arr[3].split("/")[0]) - user = source.username - password = source.password - if source.secret_id: - secrets_client = boto3.client('secretsmanager') - secret_response = secrets_client.get_secret_value(SecretId=source.secret_id) - secrets = json.loads(secret_response['SecretString']) - user = secrets['username'] - password = secrets['password'] - mysql_database = jdbc_database.MySQLDatabase(host, port, user, password) - databases = mysql_database.list_databases() - logger.info(databases) - return databases +def query_batch_status(filename: str): + success, warning, failed = 0, 0, 0 + file_key = f"{const.BATCH_CREATE_REPORT_PATH}/{filename}.xlsx" + response = __s3_client.list_objects_v2(Bucket=admin_bucket_name, Prefix=const.BATCH_CREATE_REPORT_PATH) + for obj in response.get('Contents', []): + if obj['Key'] == file_key: + response = __s3_client.get_object(Bucket=admin_bucket_name, Key=file_key) + excel_bytes = response['Body'].read() + workbook = openpyxl.load_workbook(BytesIO(excel_bytes)) + try: + sheet = workbook[const.BATCH_SHEET] + except KeyError: + raise BizException(MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_code(), + MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_msg()) + for _, row in enumerate(sheet.iter_rows(values_only=True, min_row=3)): + if row[11] == "FAILED": + failed += 1 + if row[11] == "SUCCESSED": + success += 1 + if row[11] == "WARNING": + warning += 1 + return {"success": success, "warning": warning, "failed": failed} + return 0 + +def download_batch_file(filename: str): + key = f'{const.BATCH_CREATE_REPORT_PATH}/{filename}.xlsx' + if filename.startswith("template-zh"): + key = const.BATCH_CREATE_TEMPLATE_PATH_CN + if filename.startswith("template-en"): + key = const.BATCH_CREATE_TEMPLATE_PATH_EN + url 
= __s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={'Bucket': admin_bucket_name, 'Key': key}, + ExpiresIn=60 + ) + return url + +def __gen_created_jdbc(row): + created_jdbc = JDBCInstanceSource() + created_jdbc.instance_id = row[0].value + created_jdbc.jdbc_enforce_ssl = "true" if row[1].value == 1 else "false" + created_jdbc.description = str(row[2].value) if row[2].value else const.EMPTY_STR + created_jdbc.jdbc_connection_url = str(row[3].value) + created_jdbc.secret = str(row[5].value) if row[5].value else None + created_jdbc.master_username = str(row[6].value) if row[6].value else None + created_jdbc.password = str(row[7].value) if row[7].value else None + if row[4].value and row[4].value.strip() != const.EMPTY_STR: + created_jdbc.jdbc_connection_schema = str(row[4].value).replace(",", "\n") + else: + created_jdbc.jdbc_connection_schema = None + created_jdbc.account_id = str(row[8].value) + created_jdbc.region = str(row[9].value) + created_jdbc.account_provider_id = row[10].value + created_jdbc.creation_time = "" + created_jdbc.custom_jdbc_cert = "" + created_jdbc.custom_jdbc_cert_string = "" + created_jdbc.jdbc_driver_class_name = "" + created_jdbc.jdbc_driver_jar_uri = "" + created_jdbc.last_updated_time = "" + created_jdbc.network_availability_zone = "" + # created_jdbc.secret = "" + created_jdbc.skip_custom_jdbc_cert_validation = "false" + return created_jdbc + +async def batch_add_conn_jdbc(created_jdbc_list): + tasks = [asyncio.create_task(__add_jdbc_conn_batch(jdbc)) for jdbc in created_jdbc_list] + return await asyncio.gather(*tasks) + + +def batch_sync_jdbc(jdbc_list): + return asyncio.run(batch_sync_jdbc_manager(jdbc_list)) + +async def batch_sync_jdbc_manager(jdbc_list): + tasks = [asyncio.create_task(__batch_sync_jdbc_worker(jdbc)) for jdbc in jdbc_list] + return await asyncio.gather(*tasks) + +async def __batch_sync_jdbc_worker(jdbc): + sync_jdbc_connection(jdbc) + +def __gen_jdbc_targets_from_db_names(connection_name, db_names): + jdbc_targets = [] + for db_name in db_names: + trimmed_db_name = db_name.strip() + if trimmed_db_name: + jdbc_targets.append({ + 'ConnectionName': connection_name, + 'Path': f"{trimmed_db_name}/%" + }) + return jdbc_targets + +def __update_crawler(provider_id, account_id, instance, region, jdbc_targets, crawler_name, glue_database, crawler_role_arn): + assume_account, assume_region = __get_admin_info(JDBCInstanceSourceBase(account_provider_id=provider_id, + account_id=account_id, + instance_id=instance, + region=region)) + try: + __get_glue_client(assume_account, assume_region).update_crawler( + Name=crawler_name, + Role=crawler_role_arn, + DatabaseName=glue_database, + Targets={ + 'JdbcTargets': jdbc_targets, + }, + SchemaChangePolicy={ + 'UpdateBehavior': 'UPDATE_IN_DATABASE', + 'DeleteBehavior': 'DELETE_FROM_DATABASE' + } + ) + except Exception as e: + logger.error(traceback.format_exc()) + raise BizException(MessageEnum.BIZ_UNKNOWN_ERR.get_code(), + MessageEnum.BIZ_UNKNOWN_ERR.get_msg()) + +def __get_admin_info(jdbc): + account_id = jdbc.account_id if jdbc.account_provider_id == Provider.AWS_CLOUD.value else admin_account_id + region = jdbc.region if jdbc.account_provider_id == Provider.AWS_CLOUD.value else admin_region + return account_id, region + +async def __add_jdbc_conn_batch(jdbc: JDBCInstanceSource): + try: + add_jdbc_conn(jdbc) + return jdbc.account_provider_id, jdbc.account_id, jdbc.region, jdbc.instance_id, "SUCCESSED", None + except BizException as be: + return jdbc.account_provider_id, jdbc.account_id, 
jdbc.region, jdbc.instance_id, "FAILED", be.__msg__() + except Exception as e: + logger.info(e) + return jdbc.account_provider_id, jdbc.account_id, jdbc.region, jdbc.instance_id, "FAILED", str(e) + +# TBD +def batch_delete_resource(account: AccountInfo, datasource_list: List): + pass + +def export_datasource(key: str): + default_sheet = "Sheet" + sheets = {const.S3_STR: 0, const.RDS_STR: 1, const.GLUE_STR: 2, const.JDBC_STR: 3} + workbook = openpyxl.Workbook() + # session = get_session() + # with ThreadPoolExecutor(max_workers=4) as executor: + # for sheet, index in sheets.items(): + # executor.submit(__read_and_write, sheet, index, workbook, session) + for sheet, index in sheets.items(): + __read_and_write(sheet, index, workbook) + if default_sheet in workbook.sheetnames and len(workbook.sheetnames) > 1: + sheet = workbook[default_sheet] + workbook.remove(sheet) + workbook.active = 0 + file_name = f"datasource_{key}.xlsx" + tmp_file = f"{tempfile.gettempdir()}/{file_name}" + report_file = f"{const.DATASOURCE_REPORT}/{file_name}" + workbook.save(tmp_file) + stats = os.stat(tmp_file) + if stats.st_size < 6 * 1024 * 1024: + __s3_client.upload_file(tmp_file, admin_bucket_name, report_file) + else: + concurrent_upload(admin_bucket_name, report_file, tmp_file, __s3_client) + os.remove(tmp_file) + method_parameters = {'Bucket': admin_bucket_name, 'Key': report_file} + pre_url = __s3_client.generate_presigned_url( + ClientMethod="get_object", + Params=method_parameters, + ExpiresIn=60 + ) + return pre_url + +def delete_report(key: str): + __s3_client.delete_object(Bucket=admin_bucket_name, Key=f"{const.BATCH_CREATE_REPORT_PATH}/{key}.xlsx") + +def __read_and_write(sheet_name, index, workbook): + column_mappings = { + const.S3_STR: const.EXPORT_DS_HEADER_S3, + const.RDS_STR: const.EXPORT_DS_HEADER_RDS, + const.GLUE_STR: const.EXPORT_DS_HEADER_GLUE, + const.JDBC_STR: const.EXPORT_DS_HEADER_JDBC} + column_mapping = column_mappings[sheet_name] + # query = f"SELECT {', '.join(query_columns)} FROM {table_name}" + # result = engine.execute(query) + if sheet_name == const.S3_STR: + result = crud.get_datasource_from_s3() + elif sheet_name == const.RDS_STR: + result = crud.get_datasource_from_rds() + elif sheet_name == const.GLUE_STR: + result = crud.get_datasource_from_glue() + else: + result = crud.get_datasource_from_jdbc() + + if sheet_name in workbook.sheetnames: + sheet = workbook[sheet_name] + else: + sheet = workbook.create_sheet(sheet_name, index=index) + sheet.append(column_mapping) + + for row_num, row_data in enumerate(result, start=2): + for col_num, cell_value in enumerate(row_data, start=1): + if sheet_name == const.JDBC_STR and col_num == 1: + cell_value = convert_provider_id_2_database_type(cell_value) + sheet.cell(row=row_num, column=col_num).value = cell_value diff --git a/source/constructs/api/db/models_catalog.py b/source/constructs/api/db/models_catalog.py index 24db5b4f..4a406145 100644 --- a/source/constructs/api/db/models_catalog.py +++ b/source/constructs/api/db/models_catalog.py @@ -74,6 +74,8 @@ class CatalogDatabaseLevelClassification(Base): region = sa.Column(sa.String(20), nullable=False, info={'searchable': True}) database_type = sa.Column(sa.String(20), nullable=False) database_name = sa.Column(sa.String(255), nullable=False, info={'searchable': True}) + description = sa.Column(sa.String(2048), nullable=False, info={'searchable': True}) + url = sa.Column(sa.String(2048), nullable=False, info={'searchable': True}) privacy = sa.Column(sa.SmallInteger()) sensitivity = 
sa.Column(sa.String(255), nullable=False) object_count = sa.Column(sa.BigInteger()) diff --git a/source/constructs/api/db/models_discovery_job.py b/source/constructs/api/db/models_discovery_job.py index 16246f3e..310ed8ce 100644 --- a/source/constructs/api/db/models_discovery_job.py +++ b/source/constructs/api/db/models_discovery_job.py @@ -46,10 +46,10 @@ class DiscoveryJobDatabase(Base): id = sa.Column(sa.Integer(), autoincrement=True, primary_key=True) job_id = sa.Column(sa.Integer(), sa.ForeignKey('discovery_job.id'), nullable=False) - account_id = sa.Column(sa.String(20), nullable=False) - region = sa.Column(sa.String(20), nullable=False) - database_type = sa.Column(sa.String(20), nullable=False) - database_name = sa.Column(sa.String(255), nullable=False) + account_id = sa.Column(sa.String(20)) + region = sa.Column(sa.String(20)) + database_type = sa.Column(sa.String(20)) + database_name = sa.Column(sa.String(255)) table_name = sa.Column(sa.String(1000)) base_time = sa.Column(sa.DateTime()) version = sa.Column(sa.Integer()) diff --git a/source/constructs/api/discovery_job/crud.py b/source/constructs/api/discovery_job/crud.py index 9da241e5..4834effe 100644 --- a/source/constructs/api/discovery_job/crud.py +++ b/source/constructs/api/discovery_job/crud.py @@ -9,10 +9,9 @@ from sqlalchemy import func from common.constant import const import uuid -import datetime -from catalog.crud import get_catalog_database_level_classification_by_type_all,get_catalog_database_level_classification_by_params +from datetime import datetime +from data_source.resource_list import list_resources_by_database_type from template.service import get_template_snapshot_no -from tools.str_tool import is_empty def get_job(id: int) -> models.DiscoveryJob: @@ -104,15 +103,28 @@ def get_running_run(job_id: int) -> models.DiscoveryJobRun: return db_run +def __get_database_name_key(database_type: str) -> str: + if database_type == DatabaseType.S3.value: + return "bucket_name" + elif database_type == DatabaseType.RDS.value: + return "instance_id" + elif database_type == DatabaseType.GLUE.value: + return "glue_database_name" + else: + return "instance_id" + + def __add_job_databases(run: models.DiscoveryJobRun, database_type: str, base_time_dict: dict): - databases = get_catalog_database_level_classification_by_type_all(database_type).all() - for database in databases: - base_time = base_time_dict.get(f'{database.account_id}-{database.region}-{database_type}-{database.database_name}') + data_sources = list_resources_by_database_type(database_type).all() + database_name_key = __get_database_name_key(database_type) + for data_source in data_sources: + database_name = getattr(data_source, database_name_key) + base_time = base_time_dict.get(f'{data_source.account_id}-{data_source.region}-{database_type}-{database_name}') run_database = models.DiscoveryJobRunDatabase(run_id=run.id, - account_id=database.account_id, - region=database.region, + account_id=data_source.account_id, + region=data_source.region, database_type=database_type, - database_name=database.database_name, + database_name=database_name, base_time=base_time, state=RunDatabaseState.READY.value, uuid=uuid.uuid4().hex) @@ -167,21 +179,7 @@ def init_run(job_id: int) -> int: if job.all_jdbc == 1: __add_job_databases(run, job.database_type, base_time_dict) for job_database in job_databases: - if is_empty(job_database.database_name) and is_empty(job_database.table_name): - catalog_databases = 
get_catalog_database_level_classification_by_params(job_database.account_id,job_database.region,job_database.database_type).all() - for catalog_database in catalog_databases: - base_time = base_time_dict.get( - f'{job_database.account_id}-{job_database.region}-{job_database.database_type}-{catalog_database.database_name}', datetime.datetime.min) - run_database = models.DiscoveryJobRunDatabase(run_id=run.id, - account_id=job_database.account_id, - region=job_database.region, - database_type=job_database.database_type, - database_name=catalog_database.database_name, - base_time=base_time, - state=RunDatabaseState.READY.value, - uuid=uuid.uuid4().hex) - run.databases.append(run_database) - else: + if job_database.database_name: run_database = models.DiscoveryJobRunDatabase(run_id=run.id, account_id=job_database.account_id, region=job_database.region, @@ -192,6 +190,22 @@ def init_run(job_id: int) -> int: state=RunDatabaseState.READY.value, uuid=uuid.uuid4().hex) run.databases.append(run_database) + else: + data_sources = list_resources_by_database_type(job_database.database_type, job_database.account_id, job_database.region).all() + database_name_key = __get_database_name_key(job_database.database_type) + for data_source in data_sources: + database_name = getattr(data_source, database_name_key) + base_time = base_time_dict.get( + f'{job_database.account_id}-{job_database.region}-{job_database.database_type}-{database_name}', datetime.min) + run_database = models.DiscoveryJobRunDatabase(run_id=run.id, + account_id=job_database.account_id, + region=job_database.region, + database_type=job_database.database_type, + database_name=database_name, + base_time=base_time, + state=RunDatabaseState.READY.value, + uuid=uuid.uuid4().hex) + run.databases.append(run_database) session.add(run) session.commit() return run.id @@ -237,23 +251,18 @@ def count_run_databases(run_id: int): return db_count -def stop_run(job_id: int, run_id: int, stopping=False): +def stop_run(job_id: int, run_id: int): session = get_session() run_database_update = schemas.DiscoveryJobRunDatabaseUpdate() - run_database_update.end_time = mytime.get_time() - run_database_update.state = RunDatabaseState.STOPPING.value if stopping else RunDatabaseState.STOPPED.value + run_database_update.state = RunDatabaseState.STOPPING.value session.query(models.DiscoveryJobRunDatabase).filter(models.DiscoveryJobRunDatabase.run_id == run_id).update(run_database_update.dict(exclude_unset=True)) + run_update = schemas.DiscoveryJobRunUpdate() - run_update.end_time = mytime.get_time() - run_update.state = RunState.STOPPING.value if stopping else RunState.STOPPED.value + run_update.state = RunState.STOPPING.value session.query(models.DiscoveryJobRun).filter(models.DiscoveryJobRun.id == run_id).update(run_update.dict(exclude_unset=True)) + job: models.DiscoveryJob = session.query(models.DiscoveryJob).get(job_id) - job_state = JobState.OD_STOPPING.value if stopping else JobState.IDLE.value - if job.schedule == const.ON_DEMAND: - job_state = JobState.OD_STOPPING.value if stopping else JobState.OD_COMPLETED.value - job.state = job_state - if not stopping: - job.last_end_time = mytime.get_time() + job.state = JobState.OD_STOPPING.value session.commit() @@ -293,7 +302,7 @@ def get_run_database(run_database_id: int) -> models.DiscoveryJobRunDatabase: return session.query(models.DiscoveryJobRunDatabase).get(run_database_id) -def update_job_database_base_time(job_id: int, account_id: str, region: str, database_type: str, database_name: str, base_time: 
datetime.datetime): +def update_job_database_base_time(job_id: int, account_id: str, region: str, database_type: str, database_name: str, base_time: datetime): session = get_session() job_database = schemas.DiscoveryJobDatabaseBaseTime(base_time=base_time) session.query(models.DiscoveryJobDatabase).filter(models.DiscoveryJobDatabase.job_id == job_id, @@ -304,10 +313,17 @@ def update_job_database_base_time(job_id: int, account_id: str, region: str, dat session.commit() -def get_running_run_databases() -> list[models.DiscoveryJobRunDatabase]: +def get_run_databases_by_state(state: RunDatabaseState) -> list[models.DiscoveryJobRunDatabase]: session = get_session() - db_run_databases = session.query(models.DiscoveryJobRunDatabase).filter(models.DiscoveryJobRunDatabase.state == RunDatabaseState.RUNNING.value).all() - return db_run_databases + return session.query(models.DiscoveryJobRunDatabase).filter(models.DiscoveryJobRunDatabase.state == state.value).all() + + +def get_running_run_databases() -> list[models.DiscoveryJobRunDatabase]: + return get_run_databases_by_state(RunDatabaseState.RUNNING) + + +def get_pending_run_databases() -> list[models.DiscoveryJobRunDatabase]: + return get_run_databases_by_state(RunDatabaseState.PENDING) def count_account_run_job(account_id: str, regin: str): diff --git a/source/constructs/api/discovery_job/main.py b/source/constructs/api/discovery_job/main.py index 8f893f22..d47a6ae6 100644 --- a/source/constructs/api/discovery_job/main.py +++ b/source/constructs/api/discovery_job/main.py @@ -5,18 +5,10 @@ from common.response_wrapper import BaseResponse from common.query_condition import QueryCondition from fastapi_pagination.ext.sqlalchemy import paginate -from fastapi.responses import RedirectResponse -from common.constant import const router = APIRouter(prefix="/discovery-jobs", tags=["discovery-job"]) -# @router.get("/last-job-time", response_model=BaseResponse[str]) -# @inject_session -# def last_job_time(): -# return service.last_job_time() - - @router.post("", response_model=BaseResponse[schemas.DiscoveryJob]) @inject_session def create_job(job: schemas.DiscoveryJobCreate): @@ -32,85 +24,85 @@ def list_jobs(condition: QueryCondition): )) -@router.get("/{id}", response_model=BaseResponse[schemas.DiscoveryJob]) +@router.get("/{job_id}", response_model=BaseResponse[schemas.DiscoveryJob]) @inject_session -def get_job(id: int): - return service.get_job(id) +def get_job(job_id: int): + return service.get_job(job_id) -@router.delete("/{id}", response_model=BaseResponse[bool]) +@router.delete("/{job_id}", response_model=BaseResponse[bool]) @inject_session -def delete_job(id: int): - service.delete_job(id) +def delete_job(job_id: int): + service.delete_job(job_id) return True -@router.patch("/{id}", response_model=BaseResponse[bool]) +@router.patch("/{job_id}", response_model=BaseResponse[bool]) @inject_session -def update_job(id: int, job: schemas.DiscoveryJobUpdate): - service.update_job(id, job) +def update_job(job_id: int, job: schemas.DiscoveryJobUpdate): + service.update_job(job_id, job) return True -@router.post("/{id}/enable", response_model=BaseResponse[bool]) +@router.post("/{job_id}/enable", response_model=BaseResponse[bool]) @inject_session -def enable_job(id: int): - service.enable_job(id) +def enable_job(job_id: int): + service.enable_job(job_id) return True -@router.post("/{id}/disable", response_model=BaseResponse[bool]) +@router.post("/{job_id}/disable", response_model=BaseResponse[bool]) @inject_session -def disable_job(id: int): - 
service.disable_job(id) +def disable_job(job_id: int): + service.disable_job(job_id) return True -@router.post("/{id}/start", response_model=BaseResponse[bool]) +@router.post("/{job_id}/start", response_model=BaseResponse[bool]) @inject_session -def start_job(id: int): - service.start_job(id) +def start_job(job_id: int): + service.start_job(job_id) return True -@router.post("/{id}/stop", response_model=BaseResponse[bool]) +@router.post("/{job_id}/stop", response_model=BaseResponse[bool]) @inject_session -def stop_job(id: int): - service.stop_job(id) +def stop_job(job_id: int): + service.stop_job(job_id) return True -@router.get("/{id}/runs", response_model=BaseResponse[Page[schemas.DiscoveryJobRunList]]) +@router.get("/{job_id}/runs", response_model=BaseResponse[Page[schemas.DiscoveryJobRunList]]) @inject_session -def list_runs(id: int, params: Params = Depends()): - return paginate(service.get_runs(id), params) +def list_runs(job_id: int, params: Params = Depends()): + return paginate(service.get_runs(job_id), params) -@router.get("/{id}/runs/{run_id}", response_model=BaseResponse[schemas.DiscoveryJobRun]) +@router.get("/{job_id}/runs/{run_id}", response_model=BaseResponse[schemas.DiscoveryJobRun]) @inject_session -def get_run(id: int, run_id: int): +def get_run(job_id: int, run_id: int): return service.get_run(run_id) -@router.post("/{id}/runs/{run_id}/databases", response_model=BaseResponse[Page[schemas.DiscoveryJobRunDatabaseList]]) +@router.post("/{job_id}/runs/{run_id}/databases", response_model=BaseResponse[Page[schemas.DiscoveryJobRunDatabaseList]]) @inject_session -def list_run_databases(id: int, run_id: int, condition: QueryCondition): +def list_run_databases(job_id: int, run_id: int, condition: QueryCondition): return paginate(service.list_run_databases_pagination(run_id, condition), Params( size=condition.size, page=condition.page, )) -@router.get("/{id}/runs/{run_id}/status", response_model=BaseResponse[schemas.DiscoveryJobRunDatabaseStatus]) +@router.get("/{job_id}/runs/{run_id}/status", response_model=BaseResponse[schemas.DiscoveryJobRunDatabaseStatus]) @inject_session -def get_run_status(id: int, run_id: int): - return service.get_run_status(id, run_id) +def get_run_status(job_id: int, run_id: int): + return service.get_run_status(job_id, run_id) -@router.get("/{id}/runs/{run_id}/progress", response_model=BaseResponse[list[schemas.DiscoveryJobRunDatabaseProgress]]) +@router.get("/{job_id}/runs/{run_id}/progress", response_model=BaseResponse[list[schemas.DiscoveryJobRunDatabaseProgress]]) @inject_session -def get_run_progress(id: int, run_id: int): - return service.get_run_progress(id, run_id) +def get_run_progress(job_id: int, run_id: int): + return service.get_run_progress(job_id, run_id) # @router.get("/{id}/runs/{run_id}/report", @@ -127,15 +119,15 @@ def get_run_progress(id: int, run_id: int): # return RedirectResponse(url) -@router.get("/{id}/runs/{run_id}/report_url", response_model=BaseResponse[str]) +@router.get("/{job_id}/runs/{run_id}/report_url", response_model=BaseResponse[str]) @inject_session -def get_report_url(id: int, run_id: int): - url = service.get_report_url(run_id) +def get_report_url(job_id: int, run_id: int): + url = service.get_report_url(job_id, run_id) return url -@router.get("/{id}/runs/{run_id}/template_snapshot_url", response_model=BaseResponse[str]) +@router.get("/{job_id}/runs/{run_id}/template_snapshot_url", response_model=BaseResponse[str]) @inject_session -def get_template_snapshot_url(id: int, run_id: int): +def 
get_template_snapshot_url(job_id: int, run_id: int): url = service.get_template_snapshot_url(run_id) return url diff --git a/source/constructs/api/discovery_job/schemas.py b/source/constructs/api/discovery_job/schemas.py index 8a24d7d2..ae22b5fe 100644 --- a/source/constructs/api/discovery_job/schemas.py +++ b/source/constructs/api/discovery_job/schemas.py @@ -42,6 +42,7 @@ class DiscoveryJobRunDatabaseStatus(BaseModel): success_count: int fail_count: int ready_count: int + pending_count: int running_count: int stopped_count: int not_existed_count: int @@ -49,6 +50,7 @@ class DiscoveryJobRunDatabaseStatus(BaseModel): success_per: int fail_per: int ready_per: int + pending_per: int running_per: int stopped_per: int not_existed_per: int @@ -123,8 +125,8 @@ class DiscoveryJobState(BaseModel): class DiscoveryJobBase(BaseModel): name: str template_id: int = 1 - schedule: str = "cron(0 12 * * ? *)" - description: Optional[str] + schedule: str = "OnDemand" + description: Optional[str] = "" range: int = 1 depth_structured: int = 100 depth_unstructured: Optional[int] = 10 @@ -136,10 +138,10 @@ class DiscoveryJobBase(BaseModel): all_glue: Optional[int] all_jdbc: Optional[int] overwrite: Optional[int] - exclude_keywords: Optional[str] - include_keywords: Optional[str] - exclude_file_extensions: Optional[str] - include_file_extensions: Optional[str] + exclude_keywords: Optional[str] = "" + include_keywords: Optional[str] = "" + exclude_file_extensions: Optional[str] = "" + include_file_extensions: Optional[str] = "" provider_id: Optional[int] database_type: Optional[str] diff --git a/source/constructs/api/discovery_job/service.py b/source/constructs/api/discovery_job/service.py index c6c4862d..b23bd747 100644 --- a/source/constructs/api/discovery_job/service.py +++ b/source/constructs/api/discovery_job/service.py @@ -1,4 +1,5 @@ import os +import logging import boto3 import json import db.models_discovery_job as models @@ -7,23 +8,27 @@ from common.enum import MessageEnum, JobState, RunState, RunDatabaseState, DatabaseType, AthenaQueryState from common.constant import const from common.query_condition import QueryCondition -from common.reference_parameter import logger, admin_account_id, admin_region, admin_bucket_name, partition, url_suffix, public_account_id +from common.reference_parameter import logger, admin_account_id, admin_region, admin_bucket_name, partition, url_suffix, public_account_id, admin_subnet_ids import traceback import tools.mytime as mytime import datetime, time, pytz from openpyxl import Workbook from tempfile import NamedTemporaryFile from catalog.service import sync_job_detection_result -from tools.str_tool import is_empty -from common.abilities import need_change_account_id +from common.abilities import need_change_account_id, convert_database_type_2_provider, is_run_in_admin_vpc +import config.service as config_service +from data_source import jdbc_schema +from tools import list_tool +from data_source.resource_list import list_resources_by_database_type version = os.getenv(const.VERSION, '') controller_function_name = os.getenv("ControllerFunctionName", f"{const.SOLUTION_NAME}-Controller") sqs_resource = boto3.resource('sqs') -sql_result = "SELECT database_type,account_id,region,s3_bucket,s3_location,rds_instance_id,database_name,table_name,column_name,identifiers,sample_data FROM job_detection_output_table where run_id='%d' and privacy = 1" +sql_result = "SELECT database_type,account_id,region,database_name,location,column_name,identifiers,sample_data FROM 
job_detection_output_table where run_id='%d' and privacy = 1" sql_error = "SELECT account_id,region,database_type,database_name,table_name,error_message FROM job_detection_error_table where run_id='%d'" extra_py_files = f"s3://{admin_bucket_name}/job/script/job_extra_files.zip" +report_key_template = "report/report-%d-%d.xlsx" def list_jobs(condition: QueryCondition): @@ -45,10 +50,7 @@ def create_job(job: schemas.DiscoveryJobCreate): if job.depth_structured is None: job.depth_structured = 0 if job.depth_unstructured is None: - if job.database_type == DatabaseType.S3.value: - job.depth_unstructured = -1 # -1 represents all - else: - job.depth_unstructured = 0 + job.depth_unstructured = 0 db_job = crud.create_job(job) if db_job.schedule != const.ON_DEMAND: create_event(db_job.id, db_job.schedule) @@ -75,7 +77,7 @@ def create_event(job_id: int, schedule: str): ], ) - input = {"JobId": job_id} + input = {const.CONTROLLER_ACTION: const.CONTROLLER_ACTION_SCHEDULE_JOB, "JobId": job_id} response = client_events.put_targets( Rule=rule_name, Targets=[ @@ -195,7 +197,16 @@ def disable_job(id: int): def start_job(job_id: int): run_id = crud.init_run(job_id) if run_id >= 0: - __start_run(job_id, run_id) + run = crud.get_run(run_id) + if not run.databases: + crud.complete_run(run_id) + raise BizException(MessageEnum.DISCOVERY_JOB_DATABASE_IS_EMPTY.get_code(), + MessageEnum.DISCOVERY_JOB_DATABASE_IS_EMPTY.get_msg()) + failed_run_database_count = __start_run_databases(run.databases) + if failed_run_database_count == len(run.databases): + crud.complete_run(run_id) + raise BizException(MessageEnum.DISCOVERY_JOB_ALL_RUN_FAILED.get_code(), + MessageEnum.DISCOVERY_JOB_ALL_RUN_FAILED.get_msg()) def start_sample_job(job_id: int, table_name: str): @@ -205,14 +216,41 @@ def start_sample_job(job_id: int, table_name: str): __start_sample_run(job_id, run_id, table_name) -def __start_run(job_id: int, run_id: int): - job = crud.get_job(job_id) - run = crud.get_run(run_id) - run_databases = run.databases - if not run_databases: - crud.complete_run(run_id) - raise BizException(MessageEnum.DISCOVERY_JOB_DATABASE_IS_EMPTY.get_code(), - MessageEnum.DISCOVERY_JOB_DATABASE_IS_EMPTY.get_msg()) +def __get_job_number(database_type: str) -> int: + if database_type in [DatabaseType.S3.value, DatabaseType.GLUE.value]: + return int(config_service.get_config(const.CONFIG_SUB_JOB_NUMBER_S3, const.CONFIG_SUB_JOB_NUMBER_S3_DEFAULT_VALUE)) + return int(config_service.get_config(const.CONFIG_SUB_JOB_NUMBER_RDS, const.CONFIG_SUB_JOB_NUMBER_RDS_DEFAULT_VALUE)) + + +def get_run_database_ip_count(database_type: str) -> int: + crawler_ip = 0 + if database_type.startswith(DatabaseType.JDBC.value): + crawler_ip = 3 + return crawler_ip + __get_job_number(database_type) * 2 # Each GlueJob requires 2 IPs + + +def __count_run_database_by_subnet() -> dict: + count_run_database = {} + run_databases = crud.get_running_run_databases() + for run_database in run_databases: + if not need_change_account_id(run_database.database_type): + continue + provider_id = convert_database_type_2_provider(run_database.database_type) + _, subnet_id = jdbc_schema.get_schema_by_real_time(provider_id, run_database.account_id, run_database.region, run_database.database_name) + count = count_run_database.get(subnet_id, 0) + count_run_database[subnet_id] = count + 1 + logger.info(f"count_run_database:{count_run_database}") + return count_run_database + + +def __enable_event_bridge(rule_name: str): + client_events = boto3.client('events') + 
client_events.enable_rule(Name=rule_name) + + +def __start_run_databases(run_databases): + job_dic = {} + run_dic = {} module_path = f's3://{admin_bucket_name}/job/ml-asset/python-module/' wheels = ["humanfriendly-10.0-py2.py3-none-any.whl", "protobuf-4.22.1-cp37-abi3-manylinux2014_x86_64.whl", @@ -221,9 +259,12 @@ def __start_run(job_id: int, run_id: int): "onnxruntime-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", "sdpsner-1.0.0-py3-none-any.whl", ] + limit_concurrency = False account_loop_wait = {} for run_database in run_databases: + if is_run_in_admin_vpc(run_database.database_type, run_database.account_id): + limit_concurrency = True account_id = run_database.account_id if need_change_account_id(run_database.database_type): account_id = admin_account_id @@ -234,11 +275,55 @@ def __start_run(job_id: int, run_id: int): else: account_loop_wait[account_id] = const.JOB_INTERVAL_WAIT + if limit_concurrency: + concurrent_run_job_number = int(config_service.get_config(const.CONFIG_CONCURRENT_RUN_JOB_NUMBER, const.CONFIG_CONCURRENT_RUN_JOB_NUMBER_DEFAULT_VALUE)) + logger.debug(f"concurrent_run_job_number:{concurrent_run_job_number}") + count_run_database = __count_run_database_by_subnet() + for key in account_loop_wait: + if account_loop_wait[key] > const.JOB_INTERVAL_WAIT * concurrent_run_job_number: + account_loop_wait[key] = const.JOB_INTERVAL_WAIT * concurrent_run_job_number + job_placeholder = "," account_first_wait = {} - failed_run_count = 0 + failed_run_database_count = 0 + check_pending_started = False for run_database in run_databases: try: + if is_run_in_admin_vpc(run_database.database_type, run_database.account_id): + logger.debug(f"database_name:{run_database.database_name}") + provider_id = convert_database_type_2_provider(run_database.database_type) + database_schemas_real_time, subnet_id = jdbc_schema.get_schema_by_real_time(provider_id, run_database.account_id, run_database.region, run_database.database_name, True) + count = count_run_database.get(subnet_id, 0) + logger.debug(f"subnet_id:{subnet_id}") + logger.debug(f"count:{count}") + if count >= concurrent_run_job_number: + run_database.state = RunDatabaseState.PENDING.value + logger.debug(f"{run_database.database_name} break") + if not check_pending_started: + check_pending_started = True + __enable_event_bridge(f"{const.SOLUTION_NAME}-CheckPending") + continue + logger.debug(f"run_database.database_name add") + count_run_database[subnet_id] = count + 1 + if database_schemas_real_time: + database_schemas_snapshot, _ = jdbc_schema.get_schema_by_snapshot(provider_id, run_database.account_id, run_database.region, run_database.database_name) + logger.info(f'database_schemas_real_time:{database_schemas_real_time}') + logger.info(f'database_schemas_snapshot:{database_schemas_snapshot}') + if not list_tool.compare(database_schemas_real_time, database_schemas_snapshot): + jdbc_schema.sync_schema_by_job(provider_id, run_database.account_id, run_database.region, run_database.database_name, database_schemas_real_time) + logger.info(f'Updated schema:{database_schemas_real_time}') + else: + logger.info(f'Unable to obtain the schema for {run_database.database_name}') + + run = run_dic.get(run_database.run_id) + if not run: + run = crud.get_run(run_database.run_id) + run_dic[run_database.run_id] = run + job = job_dic.get(run.job_id) + if not job: + job = crud.get_job(run.job_id) + job_dic[run.job_id] = job + account_id = run_database.account_id region = run_database.region if 
need_change_account_id(run_database.database_type): @@ -255,7 +340,7 @@ def __start_run(job_id: int, run_id: int): if job.range == 1 and run_database.base_time: base_time = mytime.format_time(run_database.base_time) need_run_crawler = True - if run_database.database_type == DatabaseType.GLUE.value or not is_empty(run_database.table_name): + if run_database.database_type == DatabaseType.GLUE.value or run_database.table_name: need_run_crawler = False crawler_name = f"{const.SOLUTION_NAME}-{run_database.database_type}-{run_database.database_name}" glue_database_name = f"{const.SOLUTION_NAME}-{run_database.database_type}-{run_database.database_name}" @@ -263,9 +348,9 @@ def __start_run(job_id: int, run_id: int): glue_database_name = run_database.database_name job_name_structured = f"{const.SOLUTION_NAME}-{run_database.database_type}-{run_database.database_name}" job_name_unstructured = f"{const.SOLUTION_NAME}-{DatabaseType.S3_UNSTRUCTURED.value}-{run_database.database_name}" - run_name = f'{const.SOLUTION_NAME}-{run_id}-{run_database.id}-{run_database.uuid}' + run_name = f'{const.SOLUTION_NAME}-{run_database.run_id}-{run_database.id}-{run_database.uuid}' # agent_bucket_name = f"{const.AGENT_BUCKET_NAME_PREFIX}-{run_database.account_id}-{run_database.region}" - unstructured_parser_job_image_uri = f"{public_account_id}.dkr.ecr.{run_database.region}.amazonaws.com{url_suffix}/aws-sensitive-data-protection-models:v1.1.0" + unstructured_parser_job_image_uri = f"{public_account_id}.dkr.ecr.{run_database.region}.amazonaws.com{url_suffix}/aws-sensitive-data-protection-models:v1.1.2" unstructured_parser_job_role = f"arn:{partition}:iam::{run_database.account_id}:role/{const.SOLUTION_NAME}UnstructuredParserRole-{run_database.region}" execution_input = { "RunName": run_name, @@ -274,7 +359,7 @@ def __start_run(job_id: int, run_id: int): "NeedRunCrawler": need_run_crawler, "CrawlerName": crawler_name, "JobId": str(job.id), # When calling Glue Job using StepFunction, the parameter must be of string type - "RunId": str(run_id), + "RunId": str(run_database.run_id), "RunDatabaseId": str(run_database.id), "AccountId": run_database.account_id, # The original account id is required here "Region": run_database.region, # The original region is required here @@ -282,15 +367,15 @@ def __start_run(job_id: int, run_id: int): "DatabaseName": run_database.database_name, "GlueDatabaseName": glue_database_name, "UnstructuredDatabaseName": f"{const.SOLUTION_NAME}-{DatabaseType.S3_UNSTRUCTURED.value}-{run_database.database_name}", - "TableName": job_placeholder if is_empty(run_database.table_name) else run_database.table_name, + "TableName": run_database.table_name if run_database.table_name else job_placeholder, "TemplateId": str(run.template_id), "TemplateSnapshotNo": str(run.template_snapshot_no), "DepthStructured": "0" if run.depth_structured is None else str(run.depth_structured), "DepthUnstructured": "0" if run.depth_unstructured is None else str(run.depth_unstructured), - "ExcludeKeywords": job_placeholder if is_empty(run.exclude_keywords) else run.exclude_keywords, - "IncludeKeywords": job_placeholder if is_empty(run.include_keywords) else run.include_keywords, - "ExcludeFileExtensions": job_placeholder if is_empty(run.exclude_file_extensions) else run.exclude_file_extensions, - "IncludeFileExtensions": job_placeholder if is_empty(run.include_file_extensions) else run.include_file_extensions, + "ExcludeKeywords": run.exclude_keywords if run.exclude_keywords else job_placeholder, + "IncludeKeywords": 
run.include_keywords if run.include_keywords else job_placeholder, + "ExcludeFileExtensions": run.exclude_file_extensions if run.exclude_file_extensions else job_placeholder, + "IncludeFileExtensions": run.include_file_extensions if run.include_file_extensions else job_placeholder, "BaseTime": base_time, # "JobBookmarkOption": job_bookmark_option, "DetectionThreshold": str(job.detection_threshold), @@ -302,6 +387,7 @@ def __start_run(job_id: int, run_id: int): "ExtraPyFiles": extra_py_files, "FirstWait": str(account_first_wait[account_id]), "LoopWait": str(account_loop_wait[account_id]), + "JobNumber": __get_job_number(run_database.database_type), "QueueUrl": f'https://sqs.{region}.amazonaws.com{url_suffix}/{admin_account_id}/{const.SOLUTION_NAME}-DiscoveryJob', "UnstructuredParserJobImageUri": unstructured_parser_job_image_uri, "UnstructuredParserJobRole": unstructured_parser_job_role, @@ -313,17 +399,14 @@ def __start_run(job_id: int, run_id: int): __exec_run(execution_input) run_database.state = RunDatabaseState.RUNNING.value except Exception: - failed_run_count += 1 + failed_run_database_count += 1 msg = traceback.format_exc() run_database.state = RunDatabaseState.FAILED.value run_database.end_time = mytime.get_time() run_database.error_log = msg logger.exception("Run StepFunction exception:%s" % msg) crud.save_run_databases(run_databases) - if failed_run_count == len(run_databases): - crud.complete_run(run_id) - raise BizException(MessageEnum.DISCOVERY_JOB_ALL_RUN_FAILED.get_code(), - MessageEnum.DISCOVERY_JOB_ALL_RUN_FAILED.get_msg()) + return failed_run_database_count def __start_sample_run(job_id: int, run_id: int, table_name: str): @@ -404,6 +487,7 @@ def __create_job(database_type: str, account_id, region, database_name, job_name NumberOfWorkers=2, WorkerType='G.1X', ExecutionProperty={'MaxConcurrentRuns': 100}, + Timeout=30 * 24 * 60, Connections={'Connections': list(connection_set)}, ) else: @@ -418,6 +502,7 @@ def __create_job(database_type: str, account_id, region, database_name, job_name NumberOfWorkers=2, WorkerType='G.1X', ExecutionProperty={'MaxConcurrentRuns': 1000}, + Timeout=30 * 24 * 60, ) @@ -431,10 +516,11 @@ def __check_sfn_version(client_sfn, arn, account_id): logger.info(f"{account_id} version is:{agent_version}") # Only check if the solution version is consistent. 
# Do not determine if the build version is consistent - agent_solution_version = agent_version.split('-')[0] - if not version.startswith(agent_solution_version): - raise BizException(MessageEnum.DISCOVERY_JOB_AGENT_MISMATCHING_VERSION.get_code(), - MessageEnum.DISCOVERY_JOB_AGENT_MISMATCHING_VERSION.get_msg()) + if os.getenv(const.MODE) != const.MODE_DEV: + agent_solution_version = agent_version.split('-')[0] + if not version.startswith(agent_solution_version): + raise BizException(MessageEnum.DISCOVERY_JOB_AGENT_MISMATCHING_VERSION.get_code(), + MessageEnum.DISCOVERY_JOB_AGENT_MISMATCHING_VERSION.get_msg()) def __exec_run(execution_input): @@ -496,19 +582,19 @@ def stop_job(job_id: int): MessageEnum.DISCOVERY_JOB_STOPPING.get_msg()) run_databases: list[models.DiscoveryJobRunDatabase] = db_run.databases - crud.stop_run(job_id, db_run.id, True) + crud.stop_run(job_id, db_run.id) job = crud.get_job(job_id) for run_database in run_databases: logger.info(f"Stop job,JobId:{job_id},RunId:{run_database.run_id},RunDatabaseId:{run_database.id}," f"AccountId:{run_database.account_id},Region:{run_database.region}," f"DatabaseType:{run_database.database_type},DatabaseName:{run_database.database_name}") - __stop_run(run_database) + __stop_step_function(run_database) __send_complete_run_database_message(job_id, run_database.run_id, run_database.id, run_database.account_id, run_database.region, run_database.database_type, run_database.database_name, job.overwrite == 1, RunDatabaseState.STOPPED.value) -def __stop_run(run_database: models.DiscoveryJobRunDatabase): +def __stop_step_function(run_database: models.DiscoveryJobRunDatabase): account_id = run_database.account_id region = run_database.region if need_change_account_id(run_database.database_type): @@ -549,8 +635,8 @@ def list_run_databases_pagination(run_id: int, condition: QueryCondition): def get_run_status(job_id: int, run_id: int) -> schemas.DiscoveryJobRunDatabaseStatus: run_list = crud.list_run_databases(run_id) - total_count = success_count = fail_count = ready_count = running_count = stopped_count = not_existed_count = 0 - success_per = fail_per = ready_per = running_per = stopped_per = not_existed_per = 0 + success_count = fail_count = ready_count = pending_count = running_count = stopped_count = not_existed_count = 0 + success_per = fail_per = ready_per = pending_per = running_per = stopped_per = not_existed_per = 0 total_count = len(run_list) if total_count > 0: @@ -561,6 +647,8 @@ def get_run_status(job_id: int, run_id: int) -> schemas.DiscoveryJobRunDatabaseS fail_count += 1 elif run_item.state == RunDatabaseState.READY.value: ready_count += 1 + elif run_item.state == RunDatabaseState.PENDING.value: + pending_count += 1 elif run_item.state == RunDatabaseState.RUNNING.value: running_count += 1 elif run_item.state == RunDatabaseState.STOPPED.value: @@ -570,6 +658,7 @@ def get_run_status(job_id: int, run_id: int) -> schemas.DiscoveryJobRunDatabaseS fail_per = int(fail_count / total_count * 100) ready_per = int(ready_count / total_count * 100) + pending_per = int(pending_count / total_count * 100) running_per = int(running_count / total_count * 100) stopped_per = int(stopped_count / total_count * 100) not_existed_per = int(not_existed_count / total_count * 100) @@ -580,6 +669,7 @@ def get_run_status(job_id: int, run_id: int) -> schemas.DiscoveryJobRunDatabaseS success_count=success_count, fail_count=fail_count, ready_count=ready_count, + pending_count=pending_count, running_count=running_count, stopped_count=stopped_count, 
not_existed_count=not_existed_count, @@ -587,6 +677,7 @@ def get_run_status(job_id: int, run_id: int) -> schemas.DiscoveryJobRunDatabaseS success_per=success_per, fail_per=fail_per, ready_per=ready_per, + pending_per=pending_per, running_per=running_per, stopped_per=stopped_per, not_existed_per=not_existed_per @@ -601,6 +692,14 @@ def get_run_progress(job_id: int, run_id: int) -> list[schemas.DiscoveryJobRunDa run_progress = [] for run_database in run.databases: try: + if run_database.state == RunDatabaseState.PENDING.value: + progress = schemas.DiscoveryJobRunDatabaseProgress(run_database_id=run_database.id, + current_table_count=-1, + table_count=-1, + current_table_count_unstructured=-1, + table_count_unstructured=-1) + run_progress.append(progress) + continue base_time = datetime.datetime.min if job.range == 1 and run_database.base_time: base_time = run_database.base_time @@ -647,13 +746,15 @@ def get_run_progress(job_id: int, run_id: int) -> list[schemas.DiscoveryJobRunDa def __get_run_current_table_count(run_id: int): sql = f"select run_database_id,database_type,count(distinct table_name) from sdps_database.job_detection_output_table where run_id='{run_id}' group by run_database_id,database_type" current_table_count = __query_athena(sql) - logger.debug(current_table_count) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(current_table_count) table_count = {} for row in current_table_count[1:]: row_result = [__get_cell_value(cell) for cell in row] key = int(row_result[0]) if row_result[1] != DatabaseType.S3_UNSTRUCTURED.value else f"{row_result[0]}-{DatabaseType.S3_UNSTRUCTURED.value}" table_count[key] = int(row_result[2]) - logger.debug(table_count) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(table_count) return table_count @@ -688,14 +789,15 @@ def __get_table_count_from_agent(run_database: models.DiscoveryJobRunDatabase, b glue_database_name = f'{const.SOLUTION_NAME}-{DatabaseType.S3_UNSTRUCTURED.value}-{run_database.database_name}' elif run_database.database_type == DatabaseType.GLUE.value: glue_database_name = run_database.database_name - next_token = "" + next_token = None count = 0 logger.info(base_time) while True: try: - response = glue.get_tables( - DatabaseName=glue_database_name, - NextToken=next_token) + if next_token: + response = glue.get_tables(DatabaseName=glue_database_name, NextToken=next_token) + else: + response = glue.get_tables(DatabaseName=glue_database_name) except Exception as e: logger.exception(e) return -1 @@ -703,7 +805,7 @@ def __get_table_count_from_agent(run_database: models.DiscoveryJobRunDatabase, b if table.get('Parameters', {}).get('classification', '') != 'UNKNOWN' and table['UpdateTime'] > base_time: count += 1 next_token = response.get('NextToken') - if next_token is None: + if not next_token: break return count @@ -724,14 +826,40 @@ def __send_complete_run_database_message(job_id, run_id, run_database_id, accoun queue.send_message(MessageBody=json.dumps(event)) -def change_run_state(run_id: int): +def __publish_job_completed(run_id: int): + topic_arn = f'arn:{partition}:sns:{admin_region}:{admin_account_id}:{const.SOLUTION_NAME}-JobCompleted' + sns = boto3.client('sns') + + job = crud.get_job_by_run_id(run_id) + if not job: + return + + message = { + 'JobId': job.id, + 'RunId': run_id, + 'Message': f'{job.name} has been completed.' 
+ } + + message_body = json.dumps(message) + try: + response = sns.publish( + TopicArn=topic_arn, + Message=message_body, + Subject=message['Message'], + ) + except Exception as e: + logger.exception(e) + + +def complete_run(run_id: int): run_databases = crud.count_run_databases(run_id) # If there are running tasks, the state is running. for state, count in run_databases: - if state == RunDatabaseState.RUNNING.value: + if state in [RunDatabaseState.RUNNING.value, RunDatabaseState.PENDING.value]: logger.info("There are also running tasks.") return crud.complete_run(run_id) + __publish_job_completed(run_id) def complete_run_database(input_event): @@ -765,7 +893,13 @@ def complete_run_database(input_event): logger.info(f'complete_run_database,JobId:{input_event["JobId"]},RunId:{input_event["RunId"]},DatabaseName:{input_event["DatabaseName"]}') -def check_running_run(): +def check_pending_run_databases(): + run_databases = crud.get_pending_run_databases() + if run_databases: + __start_run_databases(run_databases) + + +def check_running_run_databases(): run_databases = crud.get_running_run_databases() for run_database in run_databases: run_database_state, stop_time = __get_run_database_state_from_agent(run_database) @@ -783,10 +917,10 @@ def check_running_run(): if (datetime.datetime.now(pytz.timezone('UTC')) - stop_time).seconds < const.LAMBDA_MAX_RUNTIME: logger.info(f"run id:{run_database.run_id},run database id:{run_database.id} continue") continue - state = RunDatabaseState.SUCCEEDED.value + state, message = __get_run_log(run_database, False) elif run_database_state == RunDatabaseState.FAILED.value.upper(): state = RunDatabaseState.FAILED.value - message = __get_run_error_log(run_database) + _, message = __get_run_log(run_database) elif run_database_state == RunDatabaseState.ABORTED.value.upper(): state = RunDatabaseState.STOPPED.value @@ -826,7 +960,7 @@ def __get_run_database_state_from_agent(run_database: models.DiscoveryJobRunData return RunDatabaseState.NOT_EXIST.value, None -def __get_run_error_log(run_database: models.DiscoveryJobRunDatabase) -> str: +def __get_run_log(run_database: models.DiscoveryJobRunDatabase, error_log=True) -> (str, str): account_id = run_database.account_id region = run_database.region if need_change_account_id(run_database.database_type): @@ -845,27 +979,46 @@ def __get_run_error_log(run_database: models.DiscoveryJobRunDatabase) -> str: aws_session_token=credentials['SessionToken'], region_name=region, ) + max_results = 1 if error_log else 5 try: response = client_sfn.get_execution_history( executionArn=f'arn:{partition}:states:{region}:{account_id}:execution:{const.SOLUTION_NAME}-DiscoveryJob:{const.SOLUTION_NAME}-{run_database.run_id}-{run_database.id}-{run_database.uuid}', reverseOrder=True, - maxResults=1, + maxResults=max_results, ) except client_sfn.exceptions.ExecutionDoesNotExist as e: - return RunDatabaseState.NOT_EXIST.value - if response["events"][0]["type"] == "ExecutionFailed": - return response["events"][0]["executionFailedEventDetails"]["cause"] - return "" + return RunDatabaseState.NOT_EXIST.value, "" + if error_log: + if response["events"][0]["type"] == "ExecutionFailed": + return RunDatabaseState.FAILED.value, response["events"][0]["executionFailedEventDetails"]["cause"] + return RunDatabaseState.FAILED.value, "" + else: + parameters = json.loads(response["events"][4]["taskScheduledEventDetails"]["parameters"]) + result = parameters.get("MessageBody", {}).get("Result") + if result: + return result.get("State"), result.get("Message", "") + 
return RunDatabaseState.SUCCEEDED.value, "" def __get_cell_value(cell: dict): - if "VarCharValue" in cell: - return cell["VarCharValue"] - else: - return "" + return cell.get("VarCharValue", "") -def get_report_url(run_id: int): +def generate_report(job_id: int, run_id: int, s3_client=None, key_name=None): + logger.info(f"Gen job report,run id:{run_id}") + if not s3_client: + s3_client = boto3.client('s3') + if not key_name: + key_name = report_key_template % (job_id, run_id) + job = crud.get_job(job_id) + datasource_info = {} + # Starting from version v1.1, a job only has one database_type + if job.database_type.startswith(DatabaseType.JDBC.value): + data_sources = list_resources_by_database_type(job.database_type).all() + for data_source in data_sources: + datasource_key = f"{job.database_type}-{data_source.instance_id}" + datasource_info[datasource_key] = data_source + run_result = __query_athena(sql_result % run_id) wb = Workbook() @@ -876,32 +1029,30 @@ def get_report_url(run_id: int): ws_rds = wb.create_sheet("Amazon RDS") ws_jdbc = wb.create_sheet("JDBC") ws_glue = wb.create_sheet("Glue") - ws_s3_structured.append(["account_id", "region", "s3_bucket", "s3_location", "column_name", "identifiers", "sample_data"]) - ws_s3_unstructured.append(["account_id", "region", "s3_bucket", "s3_location", "identifiers", "sample_data"]) - ws_rds.append(["account_id", "region", "rds_instance_id", "table_name,", "column_name", "identifiers", "sample_data"]) - ws_jdbc.append(["type", "account_id", "region", "database_name", "table_name,", "column_name", "identifiers", "sample_data"]) - ws_glue.append(["account_id", "region", "database_name", "table_name,", "column_name", "identifiers", "sample_data"]) + ws_s3_structured.append(["account_id", "region", "bucket_name", "location", "column_name", "identifiers", "sample_data"]) + ws_s3_unstructured.append(["account_id", "region", "bucket_name", "location", "identifiers", "sample_data"]) + ws_rds.append(["account_id", "region", "instance_name", "table_name", "column_name", "identifiers", "sample_data"]) + ws_jdbc.append(["type", "account_id", "region", "instance_name", "description", "jdbc_url", "table_name", "column_name", "identifiers", "sample_data"]) + ws_glue.append(["account_id", "region", "database_name", "table_name", "column_name", "identifiers", "sample_data"]) for row in run_result[1:]: row_result = [__get_cell_value(cell) for cell in row] database_type = row_result[0] del row_result[0] if database_type == DatabaseType.S3.value: - del row_result[4:7] ws_s3_structured.append(row_result) elif database_type == DatabaseType.S3_UNSTRUCTURED.value: - del row_result[4:8] + del row_result[4] # Delete column_name field ws_s3_unstructured.append(row_result) elif database_type == DatabaseType.GLUE.value: - del row_result[2:5] ws_glue.append(row_result) elif database_type.startswith(DatabaseType.JDBC.value): - del row_result[2:5] + datasource_key = f"{database_type}-{row_result[2]}" + data_source = datasource_info[datasource_key] + row_result[3:3] = [data_source.description, data_source.jdbc_connection_url] row_result.insert(0, database_type[5:]) ws_jdbc.append(row_result) - else: - del row_result[5:6] - del row_result[2:4] + else: # RDS ws_rds.append(row_result) error_result = __query_athena(sql_error % run_id) @@ -914,10 +1065,27 @@ def get_report_url(run_id: int): filename = NamedTemporaryFile().name wb.save(filename) - s3_client = boto3.client('s3') - key_name = f"report/report-{run_id}.xlsx" s3_client.upload_file(filename, admin_bucket_name, key_name) 
os.remove(filename) + + +def __check_file_existence(s3_client, key_name): + try: + s3_client.head_object(Bucket=admin_bucket_name, Key=key_name) + return True + except Exception as e: + if e.response['Error']['Code'] == '404': + return False + else: + logger.info(e) + return False + + +def get_report_url(job_id: int, run_id: int): + s3_client = boto3.client('s3') + key_name = report_key_template % (job_id, run_id) + if not __check_file_existence(s3_client, key_name): + generate_report(job_id, run_id, s3_client, key_name) method_parameters = {'Bucket': admin_bucket_name, 'Key': key_name} pre_url = s3_client.generate_presigned_url( ClientMethod="get_object", diff --git a/source/constructs/api/lambda/auto_sync_data.py b/source/constructs/api/lambda/auto_sync_data.py index ae8cf581..28acf856 100644 --- a/source/constructs/api/lambda/auto_sync_data.py +++ b/source/constructs/api/lambda/auto_sync_data.py @@ -1,15 +1,12 @@ -import json import logging import time import boto3 from common.enum import AutoSyncDataAction, Provider from data_source.service import delete_account -from db.database import close_session, gen_session from common.reference_parameter import logger, admin_region, partition from botocore.exceptions import ClientError from common.constant import const -logger.setLevel(logging.INFO) client_sts = boto3.client('sts') @@ -37,14 +34,3 @@ def sync_data(input_event): else: break delete_account(Provider.AWS_CLOUD.value, agent_account_id, None) - - -def lambda_handler(event, context): - try: - gen_session() - for record in event['Records']: - payload = record["body"] - logger.info(payload) - sync_data(json.loads(payload)) - finally: - close_session() diff --git a/source/constructs/api/lambda/check_run.py b/source/constructs/api/lambda/check_run.py deleted file mode 100644 index 10a09ef5..00000000 --- a/source/constructs/api/lambda/check_run.py +++ /dev/null @@ -1,14 +0,0 @@ -import discovery_job.service as service -from db.database import gen_session, close_session -import logging -from common.reference_parameter import logger - -logger.setLevel(logging.INFO) - - -def lambda_handler(event, context): - try: - gen_session() - service.check_running_run() - finally: - close_session() diff --git a/source/constructs/api/lambda/controller.py b/source/constructs/api/lambda/controller.py index 015474d0..d835875a 100644 --- a/source/constructs/api/lambda/controller.py +++ b/source/constructs/api/lambda/controller.py @@ -1,16 +1,90 @@ -import discovery_job.service as service +import json +import discovery_job.service as discovery_job_service +import data_source.service as data_source_service from db.database import gen_session, close_session -import logging -from common.reference_parameter import logger +import logging.config +from common.constant import const +from . 
import auto_sync_data, sync_crawler_results +import re -logger.setLevel(logging.INFO) +logging.config.fileConfig('logging.conf', disable_existing_loggers=False) +logger = logging.getLogger(const.LOGGER_API) def lambda_handler(event, context): try: + logger.info(event) gen_session() - job_id = event["JobId"] - logger.info(f'JobId:{job_id}') - service.start_job(job_id) + if not event: + return + if "Records" in event: + __dispatch_message(event) + return + # In the old version, the only parameter for scheduled job was JobId + if "JobId" in event and len(event) == 1: + __schedule_job(event) + controller_action = event.get(const.CONTROLLER_ACTION) + if not controller_action: + return + if controller_action == const.CONTROLLER_ACTION_SCHEDULE_JOB: + __schedule_job(event) + elif controller_action == const.CONTROLLER_ACTION_CHECK_RUNNING_RUN_DATABASES: + discovery_job_service.check_running_run_databases() + elif controller_action == const.CONTROLLER_ACTION_CHECK_PENDING_RUN_DATABASES: + discovery_job_service.check_pending_run_databases() + elif controller_action == const.CONTROLLER_ACTION_REFRESH_ACCOUNT: + data_source_service.refresh_account() + else: + logger.error("Unknown action") finally: close_session() + + +def __schedule_job(event): + discovery_job_service.start_job(event["JobId"]) + + +def __replace_single_quotes(match): + return match.group(0).replace("'", "`") + + +def __deal_single_quotes(payload): + logger.info(payload) + updated_string = re.sub(r'".*?"', __replace_single_quotes, str(payload)) + payload = updated_string.replace("\'", "\"") + logger.debug(payload) + return payload + + +def __dispatch_message(event): + if event['Records'][0].get("EventSource") == "aws:sns": + __deal_sns(event) + else: + __deal_sqs(event) + + +def __deal_sns(event): + event_source = event['Records'][0]["EventSubscriptionArn"].split(":")[-2] + logger.info(f"event_source:{event_source}") + for record in event['Records']: + payload = record["Sns"]["Message"] + payload = __deal_single_quotes(payload) + current_event = json.loads(payload) + if event_source == f"{const.SOLUTION_NAME}-JobCompleted": + discovery_job_service.generate_report(int(current_event["JobId"]), int(current_event["RunId"])) + + +def __deal_sqs(event): + event_source = event['Records'][0]["eventSourceARN"].split(":")[-1] + logger.info(f"event_source:{event_source}") + for record in event['Records']: + payload = record["body"] + payload = __deal_single_quotes(payload) + current_event = json.loads(payload) + if event_source == f"{const.SOLUTION_NAME}-DiscoveryJob": + discovery_job_service.complete_run_database(current_event) + discovery_job_service.complete_run(int(current_event["RunId"])) + elif event_source == f"{const.SOLUTION_NAME}-AutoSyncData": + auto_sync_data.sync_data(current_event) + elif event_source == f"{const.SOLUTION_NAME}-Crawler": + sync_crawler_results.sync_result(current_event) diff --git a/source/constructs/api/lambda/forward_message.py b/source/constructs/api/lambda/forward_message.py index 86dc8b82..3b75ac2a 100644 --- a/source/constructs/api/lambda/forward_message.py +++ b/source/constructs/api/lambda/forward_message.py @@ -2,7 +2,7 @@ import logging import os -logger = logging.getLogger('forward_message') +logger = logging.getLogger('api') logger.setLevel(logging.INFO) admin_region = os.getenv("AdminRegion", "cn-northwest-1") sqs = boto3.resource('sqs', region_name=admin_region) diff --git a/source/constructs/api/lambda/portal_config.py b/source/constructs/api/lambda/portal_config.py index 42fb6c67..f9560f30 100644 ---
a/source/constructs/api/lambda/portal_config.py +++ b/source/constructs/api/lambda/portal_config.py @@ -2,7 +2,7 @@ import os import json -logger = logging.getLogger('portal_config') +logger = logging.getLogger('api') logger.setLevel(logging.INFO) diff --git a/source/constructs/api/lambda/receive_job_info.py b/source/constructs/api/lambda/receive_job_info.py deleted file mode 100644 index d015537f..00000000 --- a/source/constructs/api/lambda/receive_job_info.py +++ /dev/null @@ -1,23 +0,0 @@ -import discovery_job.service as service -from db.database import gen_session, close_session -import json -import logging -from common.reference_parameter import logger - -logger.setLevel(logging.INFO) - - -def main(input_event): - service.complete_run_database(input_event) - service.change_run_state(int(input_event["RunId"])) - - -def lambda_handler(event, context): - try: - gen_session() - for record in event['Records']: - payload = record["body"] - logger.info(payload) - main(json.loads(payload)) - finally: - close_session() diff --git a/source/constructs/api/lambda/refresh_account.py b/source/constructs/api/lambda/refresh_account.py deleted file mode 100644 index 9e241a1b..00000000 --- a/source/constructs/api/lambda/refresh_account.py +++ /dev/null @@ -1,15 +0,0 @@ -import data_source.service as service -from db.database import gen_session, close_session -import logging -from common.reference_parameter import logger - -logger.setLevel(logging.INFO) - - -def lambda_handler(event, context): - try: - logger.info(event) - gen_session() - service.refresh_account() - finally: - close_session() diff --git a/source/constructs/api/lambda/sync_crawler_results.py b/source/constructs/api/lambda/sync_crawler_results.py index 1fb6da99..0002ec96 100644 --- a/source/constructs/api/lambda/sync_crawler_results.py +++ b/source/constructs/api/lambda/sync_crawler_results.py @@ -1,18 +1,12 @@ -import json -import logging -import re import traceback - import catalog.service as catalog_service import data_source.crud as data_source_crud from common.abilities import convert_database_type_2_provider from common.abilities import need_change_account_id from common.constant import const from common.enum import DatabaseType, ConnectionState -from db.database import gen_session, close_session from common.reference_parameter import logger -logger.setLevel(logging.INFO) crawler_prefixes = const.SOLUTION_NAME + "-" @@ -114,18 +108,3 @@ def sync_result(input_event): state=state ) logger.debug("update jdbc datasource finished") - - -def lambda_handler(event, context): - try: - gen_session() - for record in event['Records']: - payload = record["body"] - logger.info(payload) - updated_string = re.sub(r'("[^"]*?)(\'.*?\')([^"]*?")', r'\1--\3', str(payload)) - payload = updated_string.replace("\'", "\"") - sync_result(json.loads(payload)) - except Exception: - logger.error(traceback.format_exc()) - finally: - close_session() diff --git a/source/constructs/api/logging.conf b/source/constructs/api/logging.conf index 250b8107..5df98800 100644 --- a/source/constructs/api/logging.conf +++ b/source/constructs/api/logging.conf @@ -12,7 +12,7 @@ level=INFO handlers= [logger_api] -level=INFO +level=DEBUG handlers=consoleHandler qualname=api @@ -22,4 +22,5 @@ formatter=normalFormatter args=(sys.stdout,) [formatter_normalFormatter] +class=common.log_formatter.CustomFormatter format=%(asctime)s [%(levelname)s] %(filename)s %(funcName)s() L%(lineno)-4d %(message)s diff --git a/source/constructs/api/main.py b/source/constructs/api/main.py index 
ad4aa0e2..493b842b 100644 --- a/source/constructs/api/main.py +++ b/source/constructs/api/main.py @@ -16,8 +16,9 @@ from common.constant import const from template.main import router as template_router from version.main import router as version_router -from common.exception_handler import biz_exception from label.main import router as label_router +from config.main import router as config_router +from common.exception_handler import biz_exception from fastapi_pagination import add_pagination logging.config.fileConfig('logging.conf', disable_existing_loggers=False) @@ -169,6 +170,7 @@ def __online_validate(token, jwt_claims): return False +app.include_router(config_router) app.include_router(discovery_router) app.include_router(data_source_router) app.include_router(catalog_router) diff --git a/source/constructs/api/pytest/test_labels.py b/source/constructs/api/pytest/test_labels.py index 67932b70..a1c33a2f 100644 --- a/source/constructs/api/pytest/test_labels.py +++ b/source/constructs/api/pytest/test_labels.py @@ -143,7 +143,6 @@ def test_get_labels_by_one_database(mocker): } ) assert get_labels_by_one_database.status_code == 200 - print(get_labels_by_one_database.json()) assert 'status' in get_labels_by_one_database.json() assert get_labels_by_one_database.json()['status'] == 'success' diff --git a/source/constructs/api/pytest/test_query.py b/source/constructs/api/pytest/test_query.py index ec02ec60..7622930d 100644 --- a/source/constructs/api/pytest/test_query.py +++ b/source/constructs/api/pytest/test_query.py @@ -107,7 +107,6 @@ def test_tables(mocker): headers={"authorization": "Bearer fake_token"} ) assert get_tables.status_code == 200 - print(get_tables.json()) assert 'status' in get_tables.json() assert get_tables.json()['status'] == 'success' if __name__ == '__main__': diff --git a/source/constructs/api/requirements.txt b/source/constructs/api/requirements.txt index 35609f79..8cfd4bb3 100644 --- a/source/constructs/api/requirements.txt +++ b/source/constructs/api/requirements.txt @@ -1,6 +1,6 @@ boto3==1.28.70 pytz==2023.3 -fastapi==0.109.1 +fastapi==0.109.2 mangum==0.17.0 sqlalchemy==1.4.44 fastapi-pagination==0.12.11 @@ -10,4 +10,5 @@ sqlakeyset==1.0.1659142803 requests==2.31.0 urllib3==1.26.18 python-jose==3.3.0 -pydantic==1.10.13 \ No newline at end of file +pydantic==1.10.13 +python_multipart==0.0.6 \ No newline at end of file diff --git a/source/constructs/api/search/main.py b/source/constructs/api/search/main.py index 3b1f8a9b..b45b11a4 100644 --- a/source/constructs/api/search/main.py +++ b/source/constructs/api/search/main.py @@ -1,15 +1,22 @@ -from enum import Enum +import io import json +import time +import zipfile +from datetime import datetime, timezone +from enum import Enum from typing import Optional +import boto3 from fastapi import APIRouter +from fastapi.responses import StreamingResponse from pydantic import BaseModel from common.request_wrapper import inject_session from common.response_wrapper import BaseResponse from db.models_catalog import CatalogColumnLevelClassification, CatalogTableLevelClassification, \ CatalogDatabaseLevelClassification -from db.models_data_source import S3BucketSource, Account, DetectionHistory, RdsInstanceSource, JDBCInstanceSource, SourceGlueDatabase +from db.models_data_source import S3BucketSource, Account, DetectionHistory, RdsInstanceSource, JDBCInstanceSource, \ + SourceGlueDatabase from db.models_discovery_job import DiscoveryJob, DiscoveryJobDatabase, DiscoveryJobRun, DiscoveryJobRunDatabase from db.models_template import 
TemplateIdentifier, TemplateMapping from . import crud @@ -90,6 +97,7 @@ def filter_values(table: str, column: str, condition: str): values.append('Empty') return values + @router.post("/", response_model=BaseResponse) @inject_session def filter_values(query: Query): @@ -118,3 +126,36 @@ def tables(): for searchable_class in searchable: tables.append(searchable_class.__tablename__) return tables + + +@router.get("/download-logs", response_class=StreamingResponse) +def download_log_as_zip(): + filename = f"aws_sdps_cloudwatch_logs.zip" + logs = boto3.client('logs') + response = logs.describe_log_groups(logGroupNamePattern='APIAPIFunction') + log_group_names = [group['logGroupName'] for group in response['logGroups']] + end_time = int(time.time()) * 1000 + start_time = end_time - 1 * 24 * 60 * 60 * 1000 # recent 1 days logs + + zip_bytes = io.BytesIO() + with zipfile.ZipFile(zip_bytes, 'w') as zipf: + for log_group_name in log_group_names: + response = logs.filter_log_events( + logGroupName=log_group_name, + startTime=start_time, + endTime=end_time, + interleaved=True + ) + log_events = sorted(response['events'], key=lambda x: x['timestamp'], reverse=True) + log_file_name = f'{log_group_name}.txt' + log_content = [] + for event in log_events: + timestamp = event['timestamp'] / 1000 # to seconds + timestamp_str = datetime.fromtimestamp(timestamp, timezone.utc).isoformat() + log_content.append(f'{timestamp_str}\t {event["message"]}') + + zipf.writestr(log_file_name, '\n'.join(log_content)) + + zip_bytes.seek(0) + return StreamingResponse(zip_bytes, media_type="application/zip", + headers={"Content-Disposition": f"attachment; filename={filename}"}) diff --git a/source/constructs/api/template/crud.py b/source/constructs/api/template/crud.py index 83f28c8a..9f295bb2 100644 --- a/source/constructs/api/template/crud.py +++ b/source/constructs/api/template/crud.py @@ -265,8 +265,29 @@ def update_prop(id: int, prop: schemas.TemplateIdentifierProp): session.commit() return snapshot_no, get_identifier(id) - -def get_refs_by_prop(id:int): - +def get_refs_by_prop(id: int): return get_session().query(models.TemplateIdentifierPropRef).filter( models.TemplateIdentifierPropRef.prop_id == id).all() + +def get_all_identifiers(): + return get_session().query( + models.TemplateIdentifier.id, + models.TemplateIdentifier.name, + models.TemplateIdentifier.description, + models.TemplateIdentifier.rule, + models.TemplateIdentifier.header_keywords, + models.TemplateIdentifier.exclude_keywords, + models.TemplateIdentifier.max_distance, + models.TemplateIdentifier.min_occurrence, + ).filter(models.TemplateIdentifier.type == 1).all() + +def get_identifier_prop_mapping(): + return get_session().query( + models.TemplateIdentifier.id, + models.TemplateIdentifierProp.prop_type, + models.TemplateIdentifierProp.prop_name + ).filter(models.TemplateIdentifier.type == 1).outerjoin(models.TemplateIdentifierPropRef, + models.TemplateIdentifierPropRef.identifier_id == models.TemplateIdentifier.id + ).outerjoin(models.TemplateIdentifierProp, + models.TemplateIdentifierProp.id == models.TemplateIdentifierPropRef.prop_id + ).all() diff --git a/source/constructs/api/template/main.py b/source/constructs/api/template/main.py index aaea9424..eb849bf5 100644 --- a/source/constructs/api/template/main.py +++ b/source/constructs/api/template/main.py @@ -1,4 +1,5 @@ -from fastapi import APIRouter +from typing import List +from fastapi import APIRouter, File, UploadFile from fastapi_pagination import Page, Params from 
fastapi_pagination.ext.sqlalchemy import paginate from common.request_wrapper import inject_session @@ -118,3 +119,29 @@ def delete_prop(id: int): @inject_session def get_template_time(tid: int): return service.get_template_snapshot_no(tid) + + +@router.post("/export-identify", response_model=BaseResponse) +@inject_session +def export_identify(key: str): + return service.export_identify(key) + +@router.post("/delete-report", response_model=BaseResponse) +@inject_session +def delete_report(key: str): + return service.delete_report(key) + +@router.post("/batch-create", response_model=BaseResponse) +@inject_session +def batch_create(files: List[UploadFile] = File(...)): + return service.batch_create(files[0]) + +@router.post("/query-batch-status", response_model=BaseResponse) +@inject_session +def query_batch_status(batch: str): + return service.query_batch_status(batch) + +@router.post("/download-batch-file", response_model=BaseResponse) +@inject_session +def download_batch_file(filename: str): + return service.download_batch_file(filename) diff --git a/source/constructs/api/template/service.py b/source/constructs/api/template/service.py index 2bebf321..e2328a05 100644 --- a/source/constructs/api/template/service.py +++ b/source/constructs/api/template/service.py @@ -1,5 +1,12 @@ +import asyncio +from io import BytesIO import json +import os +import tempfile +import time import boto3 +from fastapi import File, UploadFile +import openpyxl from common.constant import const from common.response_wrapper import S3WrapEncoder from common.exception_handler import BizException @@ -7,11 +14,13 @@ from common.query_condition import QueryCondition from common.reference_parameter import admin_bucket_name from catalog.service_dashboard import get_database_by_identifier +from common import concurrent_upload2s3 +from common.abilities import insert_error_msg_2_cells, insert_success_2_cells from template import schemas, crud caller_identity = boto3.client('sts').get_caller_identity() - +__s3_client = boto3.client('s3') def get_identifiers(condition: QueryCondition): return crud.get_identifiers(condition) @@ -28,11 +37,6 @@ def get_identifiers_by_template(tid: int): def get_identifier(id: int): return crud.get_identifier(id) - -# def get_identifier_by_names(names: list): -# return crud.get_identifier_by_names(names) - - def create_identifier(identifier: schemas.TemplateIdentifier): res_list = crud.get_identify_by_name(identifier.name) if res_list: @@ -167,10 +171,230 @@ def sync_s3(snapshot_no): ) -def check_rule(identifier): +def check_rule(identifier: schemas.TemplateIdentifier): if identifier.type == IdentifierType.CUSTOM.value and identifier.rule == const.EMPTY_STR: raise BizException(MessageEnum.TEMPLATE_IDENTIFIER_RULES_EMPTY.get_code(), MessageEnum.TEMPLATE_IDENTIFIER_RULES_EMPTY.get_msg()) if identifier.type == IdentifierType.CUSTOM.value and identifier.header_keywords and '""' in identifier.header_keywords[1:-1].split(","): raise BizException(MessageEnum.TEMPLATE_HEADER_KEYWORDS_EMPTY.get_code(), MessageEnum.TEMPLATE_HEADER_KEYWORDS_EMPTY.get_msg()) + + +def export_identify(key): + default_sheet = "Sheet" + workbook = openpyxl.Workbook() + sheet = workbook.create_sheet(const.BATCH_SHEET, index=0) + if default_sheet in workbook.sheetnames and len(workbook.sheetnames) > 1: + origin_sheet = workbook[default_sheet] + workbook.remove(origin_sheet) + sheet.append(const.EXPORT_IDENTIFY_HEADER) + result = crud.get_all_identifiers() + props_list = crud.get_identifier_prop_mapping() + props_mapping = 
__convert_prop_list_2_mapping(props_list) + for row_num, row_data in enumerate(result, start=2): + category = props_mapping.get(f"{row_data[0]}-1") + label = props_mapping.get(f"{row_data[0]}-2") + # del row_data[0] + for col_num, cell_value in enumerate(row_data, start=1): + if col_num == 1: + continue + if col_num == 5 or col_num == 6: + sheet.cell(row=row_num, column=col_num - 1).value = cell_value.replace('\"', '')[1:-1] if cell_value else cell_value + else: + sheet.cell(row=row_num, column=col_num - 1).value = cell_value + if category: + sheet.cell(row=row_num, column=8, value=category) + if label: + sheet.cell(row=row_num, column=9, value=label) + workbook.active = 0 + file_name = f"identify_{key}.xlsx" + tmp_file = f"{tempfile.gettempdir()}/{file_name}" + report_file = f"{const.IDENTIFY_REPORT}/{file_name}" + workbook.save(tmp_file) + stats = os.stat(tmp_file) + if stats.st_size < 6 * 1024 * 1024: + __s3_client.upload_file(tmp_file, admin_bucket_name, report_file) + else: + concurrent_upload2s3(admin_bucket_name, report_file, tmp_file, __s3_client) + os.remove(tmp_file) + method_parameters = {'Bucket': admin_bucket_name, 'Key': report_file} + pre_url = __s3_client.generate_presigned_url( + ClientMethod="get_object", + Params=method_parameters, + ExpiresIn=60 + ) + return pre_url + +def __convert_prop_list_2_mapping(list): + mapping = {} + for item in list: + mapping[f"{item[0]}-{item[1]}"] = item[2] + return mapping + +def delete_report(key): + __s3_client.delete_object(Bucket=admin_bucket_name, Key=f"{const.BATCH_CREATE_IDENTIFIER_REPORT_PATH}/{key}.xlsx") + + +def query_batch_status(filename: str): + success, warning, failed = 0, 0, 0 + file_key = f"{const.BATCH_CREATE_IDENTIFIER_REPORT_PATH}/{filename}.xlsx" + response = __s3_client.list_objects_v2(Bucket=admin_bucket_name, Prefix=const.BATCH_CREATE_IDENTIFIER_REPORT_PATH) + for obj in response.get('Contents', []): + if obj['Key'] == file_key: + response = __s3_client.get_object(Bucket=admin_bucket_name, Key=file_key) + excel_bytes = response['Body'].read() + workbook = openpyxl.load_workbook(BytesIO(excel_bytes)) + try: + sheet = workbook[const.BATCH_SHEET] + except KeyError: + raise BizException(MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_code(), + MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_msg()) + for _, row in enumerate(sheet.iter_rows(values_only=True, min_row=3)): + if row[9] == "FAILED": + failed += 1 + if row[9] == "SUCCESSED": + success += 1 + if row[9] == "WARNING": + warning += 1 + return {"success": success, "warning": warning, "failed": failed} + return 0 + +def download_batch_file(filename: str): + key = f'{const.BATCH_CREATE_IDENTIFIER_REPORT_PATH}/{filename}.xlsx' + if filename.startswith("identifier-template-zh"): + key = const.BATCH_CREATE_IDENTIFIER_TEMPLATE_PATH_CN + if filename.startswith("identifier-template-en"): + key = const.BATCH_CREATE_IDENTIFIER_TEMPLATE_PATH_EN + url = __s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={'Bucket': admin_bucket_name, 'Key': key}, + ExpiresIn=60 + ) + return url + +def batch_create(file: UploadFile = File(...)): + res_column_index = 10 + time_str = time.time() + identifier_from_excel_set = set() + created_identifier_list = [] + category_list = crud.get_props_by_type(1) + label_list = crud.get_props_by_type(2) + # Check if the file is an Excel file + if not file.filename.endswith('.xlsx'): + raise BizException(MessageEnum.SOURCE_BATCH_CREATE_FORMAT_ERR.get_code(), + MessageEnum.SOURCE_BATCH_CREATE_FORMAT_ERR.get_msg()) + # Read the Excel file + 
content = file.file.read() + workbook = openpyxl.load_workbook(BytesIO(content), read_only=False) + try: + sheet = workbook.get_sheet_by_name(const.BATCH_SHEET) + except KeyError: + raise BizException(MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_code(), + MessageEnum.SOURCE_BATCH_SHEET_NOT_FOUND.get_msg()) + header = [cell for cell in sheet.iter_rows(min_row=2, max_row=2, values_only=True)][0] + sheet.delete_cols(10, amount=2) + sheet.insert_cols(10, amount=2) + sheet.cell(row=2, column=10, value="Result") + sheet.cell(row=2, column=11, value="Details") + identifiers = crud.get_all_identifiers() + identifier_list = [identifier[0] for identifier in identifiers] + no_content = True + for row_index, row in enumerate(sheet.iter_rows(min_row=3), start=2): + props = [] + if all(cell.value is None for cell in row): + continue + no_content = False + res, msg = __check_empty_for_field(row, header) + if res: + insert_error_msg_2_cells(sheet, row_index, msg, res_column_index) + elif f"{row[0].value}" in identifier_from_excel_set: + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[0]} already exists in the preceding rows", res_column_index) + elif not row[3].value and not row[4].value: + # Content validation rules and title keywords validation rules cannot be empty at the same time. + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[3]} and {header[4]} cannot be empty at the same time.", res_column_index) + elif not __is_pos_int_or_none(row[5]): + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[5]} must be a positive integer.", res_column_index) + elif not __is_pos_int_or_none(row[6]): + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[6]} must be a positive integer.", res_column_index) + elif row[7].value and not [category for category in category_list if category.prop_name.lower() == row[7].value.strip().lower()]: + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[7]} does not exist in the system, please check", res_column_index) + elif row[8].value and not [label for label in label_list if label.prop_name.lower() == row[8].value.strip().lower()]: + insert_error_msg_2_cells(sheet, row_index, f"The value of {header[8]} does not exist in the system, please check", res_column_index) + elif row[0].value in identifier_list: + # Account.account_provider_id, Account.account_id, Account.region + insert_error_msg_2_cells(sheet, row_index, "A data identifier with the same name already exists", res_column_index) + else: + identifier_from_excel_set.add(row[0].value) + if row[7].value: + categories = [category for category in category_list if category.prop_name.lower() == row[7].value.strip().lower()] + props.append(categories[0].id) + if row[8].value: + labels = [label for label in label_list if label.prop_name.lower() == row[8].value.strip().lower()] + props.append(labels[0].id) + # account_set.add(f"{row[10].value}/{row[8].value}/{row[9].value}") + created_identifier_list.append(__gen_created_identifier(row, props)) + if no_content: + raise BizException(MessageEnum.SOURCE_BATCH_SHEET_NO_CONTENT.get_code(), + MessageEnum.SOURCE_BATCH_SHEET_NO_CONTENT.get_msg()) + batch_result = asyncio.run(batch_add_identifier(created_identifier_list)) + result = {item[0]: f"{item[1]}/{item[2]}" for item in batch_result} + for row_index, row in enumerate(sheet.iter_rows(min_row=3), start=2): + if row[10] and row[10].value: + continue + v = result.get(row[0].value) + if v: + if v.split('/')[0] == "SUCCESSED": + insert_success_2_cells(sheet,
row_index, res_column_index) + else: + insert_error_msg_2_cells(sheet, row_index, v.split('/')[1], res_column_index) + # Write into excel + excel_bytes = BytesIO() + workbook.save(excel_bytes) + excel_bytes.seek(0) + # Upload to S3 + batch_create_ds = f"{const.BATCH_CREATE_IDENTIFIER_REPORT_PATH}/report_{time_str}.xlsx" + __s3_client.upload_fileobj(excel_bytes, admin_bucket_name, batch_create_ds) + return f'report_{time_str}' + +def __check_empty_for_field(row, header): + if row[0].value is None or str(row[0].value).strip() == const.EMPTY_STR: + return True, f"{header[0]} should not be empty" + # if row[2].value is None or str(row[2].value).strip() == const.EMPTY_STR: + # return True, f"{header[2]} should not be empty" + return False, None + +def __gen_created_identifier(row, props): + created_identifier = schemas.TemplateIdentifier() + created_identifier.name = row[0].value + created_identifier.description = str(row[1].value) + created_identifier.props = props + created_identifier.rule = row[2].value + created_identifier.header_keywords = json.dumps(str(row[3].value).split(","), ensure_ascii=False) if row[3].value else None + created_identifier.exclude_keywords = json.dumps(str(row[4].value).split(","), ensure_ascii=False) if row[4].value else None + created_identifier.max_distance = row[5].value + created_identifier.min_occurrence = row[6].value + created_identifier.type = IdentifierType.CUSTOM.value + return created_identifier + +def __is_pos_int_or_none(cell): + if not cell or not cell.value: + return True + try: + if int(cell.value) > 0: + return True + except Exception as e: + return False + return False + +async def batch_add_identifier(created_identifier_list): + tasks = [asyncio.create_task(__add_create_identifier_batch(identifier)) for identifier in created_identifier_list] + return await asyncio.gather(*tasks) + +async def __add_create_identifier_batch(identifier: schemas.TemplateIdentifier): + try: + create_identifier(identifier) + return identifier.name, "SUCCESSED", None + except BizException as be: + return identifier.name, "FAILED", be.__msg__() + except Exception as e: + return identifier.name, "FAILED", str(e) diff --git a/source/constructs/api/tools/list_tool.py b/source/constructs/api/tools/list_tool.py new file mode 100644 index 00000000..da877b50 --- /dev/null +++ b/source/constructs/api/tools/list_tool.py @@ -0,0 +1,4 @@ +def compare(list1: list, list2: list): + sorted_list1 = sorted(list1) + sorted_list2 = sorted(list2) + return sorted_list1 == sorted_list2 diff --git a/source/constructs/api/tools/str_tool.py b/source/constructs/api/tools/str_tool.py deleted file mode 100644 index b42bcc3d..00000000 --- a/source/constructs/api/tools/str_tool.py +++ /dev/null @@ -1,7 +0,0 @@ -from common.constant import const - - -def is_empty(in_str: str) -> bool: - if in_str is None or in_str == const.EMPTY_STR: - return True - return False diff --git a/source/constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-cn.xlsx b/source/constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-cn.xlsx new file mode 100644 index 00000000..90637ca2 Binary files /dev/null and b/source/constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-cn.xlsx differ diff --git a/source/constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-en.xlsx b/source/constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-en.xlsx new file mode 100644 index 00000000..cf7d22ea Binary files 
/dev/null and b/source/constructs/config/batch_create/datasource/template/batch_create_jdbc_datasource-en.xlsx differ diff --git a/source/constructs/config/batch_create/identifier/template/batch_create_identifier-cn.xlsx b/source/constructs/config/batch_create/identifier/template/batch_create_identifier-cn.xlsx new file mode 100644 index 00000000..51762dc1 Binary files /dev/null and b/source/constructs/config/batch_create/identifier/template/batch_create_identifier-cn.xlsx differ diff --git a/source/constructs/config/batch_create/identifier/template/batch_create_identifier-en.xlsx b/source/constructs/config/batch_create/identifier/template/batch_create_identifier-en.xlsx new file mode 100644 index 00000000..6bafa269 Binary files /dev/null and b/source/constructs/config/batch_create/identifier/template/batch_create_identifier-en.xlsx differ diff --git a/source/constructs/config/job/script/glue-job-unstructured.py b/source/constructs/config/job/script/glue-job-unstructured.py index 6a2944aa..8d3626c8 100644 --- a/source/constructs/config/job/script/glue-job-unstructured.py +++ b/source/constructs/config/job/script/glue-job-unstructured.py @@ -95,6 +95,7 @@ 'database_name': args['DatabaseName'], 'database_type': args['DatabaseType'], 'table_name': table['Name'], + 'location': basic_table_info['location'], 's3_location': basic_table_info['s3_location'], 's3_bucket': basic_table_info['s3_bucket'], 'rds_instance_id': basic_table_info['rds_instance_id'], diff --git a/source/constructs/config/job/script/glue-job.py b/source/constructs/config/job/script/glue-job.py index 43c3f736..7fa8d347 100644 --- a/source/constructs/config/job/script/glue-job.py +++ b/source/constructs/config/job/script/glue-job.py @@ -93,6 +93,7 @@ 'database_name': args['DatabaseName'], 'database_type': args['DatabaseType'], 'table_name': table['Name'], + 'location': basic_table_info['location'], 's3_location': basic_table_info['s3_location'], 's3_bucket': basic_table_info['s3_bucket'], 'rds_instance_id': basic_table_info['rds_instance_id'], diff --git a/source/constructs/config/job/script/job_extra_files.zip b/source/constructs/config/job/script/job_extra_files.zip index ad65e9e0..b895bc51 100644 Binary files a/source/constructs/config/job/script/job_extra_files.zip and b/source/constructs/config/job/script/job_extra_files.zip differ diff --git a/source/constructs/lib/admin-stack.ts b/source/constructs/lib/admin-stack.ts index 90d9608c..248494e5 100755 --- a/source/constructs/lib/admin-stack.ts +++ b/source/constructs/lib/admin-stack.ts @@ -209,6 +209,7 @@ export class AdminStack extends Stack { vpc: vpcStack.vpc, bucketName: bucketStack.bucket.bucketName, rdsClientSecurityGroup: rdsStack.clientSecurityGroup, + customDBSecurityGroup: vpcStack.customDBSecurityGroup, oidcIssuer: oidcIssuerValue, oidcClientId: oidcClientIdValue, }); diff --git a/source/constructs/lib/admin/alb-stack.ts b/source/constructs/lib/admin/alb-stack.ts index a21c5ea1..599383c0 100644 --- a/source/constructs/lib/admin/alb-stack.ts +++ b/source/constructs/lib/admin/alb-stack.ts @@ -165,7 +165,7 @@ export class AlbStack extends NestedStack { this.createApi(listener, props); this.createProtalConfig(listener, props); - this.createPortal(listener); + this.createPortal(listener, props); }; private setUrl(scope: Construct, dnsName: string, props: AlbProps, defaultPort: number) { @@ -225,6 +225,7 @@ export class AlbStack extends NestedStack { }); albSecurityGroup.addIngressRule(Peer.anyIpv4(), Port.tcp(port), 'rule of allow inbound traffic from server port'); 
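For context on the new location field added to the job scripts above, the per-table record they emit roughly takes the shape sketched below; it mirrors the location column added to the Glue table definition elsewhere in this patch. All values are placeholders; only the key names come from the scripts.

```python
# Placeholder record showing the keys the detection job scripts write once the
# new 'location' field is added next to the existing S3/RDS fields.
record = {
    "database_name": "example_db",                  # placeholder
    "database_type": "s3",                          # placeholder
    "table_name": "example_table",                  # placeholder
    "location": "s3://example-bucket/prefix/",      # newly added field
    "s3_location": "s3://example-bucket/prefix/",   # placeholder
    "s3_bucket": "example-bucket",                  # placeholder
    "rds_instance_id": "",                          # placeholder
}
```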
albSecurityGroup.addIngressRule(Peer.anyIpv6(), Port.tcp(port), 'rule of allow inbound traffic from server port'); + Tags.of(albSecurityGroup).add(SolutionInfo.TAG_NAME, `${SolutionInfo.SOLUTION_NAME}-ALB`); return albSecurityGroup; } @@ -271,7 +272,7 @@ export class AlbStack extends NestedStack { Tags.of(portalConfigTargetGroup).add(SolutionInfo.TAG_NAME, 'PortalConfig'); } - private createPortal(listener: ApplicationListener) { + private createPortal(listener: ApplicationListener, props: AlbProps) { let portalFunction; if (BuildConfig.PortalRepository && BuildConfig.PortalTag) { portalFunction = new DockerImageFunction(this, 'PortalFunction', { @@ -280,6 +281,9 @@ export class AlbStack extends NestedStack { code: DockerImageCode.fromEcr(Repository.fromRepositoryArn(this, 'PortalRepository', BuildConfig.PortalRepository), { tagOrDigest: BuildConfig.PortalTag }), architecture: Architecture.X86_64, + environment: { + OidcIssuer: props.oidcIssuer, + }, }); } else { portalFunction = new DockerImageFunction(this, 'PortalFunction', { @@ -291,6 +295,9 @@ export class AlbStack extends NestedStack { platform: Platform.LINUX_AMD64, }), architecture: Architecture.X86_64, + environment: { + OidcIssuer: props.oidcIssuer, + }, }); } const portalTarget = [new LambdaTarget(portalFunction)]; diff --git a/source/constructs/lib/admin/api-stack.ts b/source/constructs/lib/admin/api-stack.ts index 1d8766e7..864443c9 100644 --- a/source/constructs/lib/admin/api-stack.ts +++ b/source/constructs/lib/admin/api-stack.ts @@ -36,18 +36,21 @@ import { Code, AssetCode, LayerVersion, - FunctionOptions, } from 'aws-cdk-lib/aws-lambda'; import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; +import { Topic } from 'aws-cdk-lib/aws-sns'; +import { LambdaSubscription } from 'aws-cdk-lib/aws-sns-subscriptions'; import { Construct } from 'constructs'; import { SqsStack } from './sqs-stack'; import { BuildConfig } from '../common/build-config'; import { SolutionInfo } from '../common/solution-info'; +import { Alias } from 'aws-cdk-lib/aws-kms'; export interface ApiProps { readonly vpc: IVpc; readonly bucketName: string; readonly rdsClientSecurityGroup: SecurityGroup; + readonly customDBSecurityGroup: SecurityGroup; readonly oidcIssuer: string; readonly oidcClientId: string; } @@ -66,35 +69,49 @@ export class ApiStack extends Construct { this.apiLayer = this.createLayer(); this.code = Code.fromAsset(path.join(__dirname, '../../api'), { exclude: ['venv', 'pytest'] }); - this.createFunction('Controller', 'lambda.controller.lambda_handler', props, 20, `${SolutionInfo.SOLUTION_NAME}-Controller`); - this.apiFunction = this.createFunction('API', 'main.handler', props, 900); - const checkRunFunction = this.createFunction('CheckRun', 'lambda.check_run.lambda_handler', props, 600); - const checkRunRule = new events.Rule(this, 'CheckRunRule', { + const controllerFunction = this.createFunction('Controller', 'lambda.controller.lambda_handler', props, 900, `${SolutionInfo.SOLUTION_NAME}-Controller`); + + const checkRunningRule = new events.Rule(this, 'CheckRunningRule', { // ruleName: `${SolutionInfo.SOLUTION_NAME}-CheckRun`, schedule: events.Schedule.cron({ minute: '0/30' }), }); - checkRunRule.addTarget(new targets.LambdaFunction(checkRunFunction)); - Tags.of(checkRunRule).add(SolutionInfo.TAG_KEY, SolutionInfo.TAG_VALUE); + checkRunningRule.addTarget(new targets.LambdaFunction(controllerFunction, { + event: events.RuleTargetInput.fromObject({ Action: 'CheckRunningRunDatabases' }), + })); + 
Tags.of(checkRunningRule).add(SolutionInfo.TAG_KEY, SolutionInfo.TAG_VALUE); + const checkPendingRule = new events.Rule(this, 'CheckPendingRule', { + ruleName: `${SolutionInfo.SOLUTION_NAME}-CheckPending`, + schedule: events.Schedule.rate(Duration.minutes(1)), + enabled: false, + }); + checkPendingRule.addTarget(new targets.LambdaFunction(controllerFunction, { + event: events.RuleTargetInput.fromObject({ Action: 'CheckPendingRunDatabases' }), + })); + Tags.of(checkPendingRule).add(SolutionInfo.TAG_KEY, SolutionInfo.TAG_VALUE); - const receiveJobInfoFunction = this.createFunction('ReceiveJobInfo', 'lambda.receive_job_info.lambda_handler', props, 900); const discoveryJobSqsStack = new SqsStack(this, 'DiscoveryJobQueue', { name: 'DiscoveryJob', visibilityTimeout: 900 }); const discoveryJobEventSource = new SqsEventSource(discoveryJobSqsStack.queue); - receiveJobInfoFunction.addEventSource(discoveryJobEventSource); + controllerFunction.addEventSource(discoveryJobEventSource); - const updateCatalogFunction = this.createFunction('UpdateCatalog', 'lambda.sync_crawler_results.lambda_handler', props, 900); const crawlerSqsStack = new SqsStack(this, 'CrawlerQueue', { name: 'Crawler', visibilityTimeout: 900 }); const crawlerEventSource = new SqsEventSource(crawlerSqsStack.queue); - updateCatalogFunction.addEventSource(crawlerEventSource); + controllerFunction.addEventSource(crawlerEventSource); - const autoSyncDataFunction = this.createFunction('AutoSyncData', 'lambda.auto_sync_data.lambda_handler', props, 900); - // Set delivery delay to 10 minutes to wait for agent stack to be deleted const autoSyncDataSqsStack = new SqsStack(this, 'AutoSyncDataQueue', { name: 'AutoSyncData', visibilityTimeout: 900 }); const autoSyncDataEventSource = new SqsEventSource(autoSyncDataSqsStack.queue); - autoSyncDataFunction.addEventSource(autoSyncDataEventSource); + controllerFunction.addEventSource(autoSyncDataEventSource); + + this.createJobCompletedTopic(controllerFunction); + } - this.createFunction('RefreshAccount', 'lambda.refresh_account.lambda_handler', props, 60, `${SolutionInfo.SOLUTION_NAME}-RefreshAccount`); + private createJobCompletedTopic(controllerFunction: Function) { + const jobCompletedTopic = new Topic(this, 'JobCompleted', { + topicName: `${SolutionInfo.SOLUTION_NAME}-JobCompleted`, + masterKey: Alias.fromAliasName(this, 'MasterKey', 'alias/aws/sns'), + }); + jobCompletedTopic.addSubscription(new LambdaSubscription(controllerFunction)); } private createFunction(name: string, handler: string, props: ApiProps, timeout?: number, functionName?: string) { @@ -111,6 +128,7 @@ export class ApiStack extends Construct { vpcSubnets: props.vpc.selectSubnets({ subnetType: SubnetType.PRIVATE_WITH_EGRESS, }), + // securityGroups: [props.rdsClientSecurityGroup, props.customDBSecurityGroup], securityGroups: [props.rdsClientSecurityGroup], environment: { AdminBucketName: props.bucketName, @@ -138,6 +156,8 @@ export class ApiStack extends Construct { 'logs:CreateLogGroup', 'logs:CreateLogStream', 'logs:PutLogEvents', + 'logs:DescribeLogGroups', + 'logs:FilterLogEvents', ], resources: ['*'], }); @@ -145,18 +165,27 @@ export class ApiStack extends Construct { const functionStatement = new PolicyStatement({ effect: Effect.ALLOW, - actions: ['sqs:DeleteMessage', + actions: [ + 'lambda:AddPermission', + 'lambda:RemovePermission', + 'sqs:DeleteMessage', 'sqs:ChangeMessageVisibility', 'sqs:GetQueueUrl', - 'athena:StartQueryExecution', - 'events:EnableRule', 'sqs:SendMessage', 'sqs:ReceiveMessage', - 'events:PutRule', - 
'athena:GetQueryResults', 'sqs:GetQueueAttributes', 'sqs:SetQueueAttributes', + 'secretsmanager:GetSecretValue', + 's3:PutObject', + 's3:DeleteObject', + 's3:GetObject', + 's3:GetBucketLocation', + 's3:PutBucketPolicy', + 's3:GetBucketPolicy', 's3:ListBucket', + 'athena:StartQueryExecution', + 'athena:GetQueryResults', + 'athena:GetQueryExecution', 'glue:CreateDatabase', 'glue:GetDatabase', 'glue:GetDatabases', @@ -172,23 +201,19 @@ export class ApiStack extends Construct { 'glue:GetPartition', 'glue:GetPartitions', 'glue:BatchGetPartition', - 's3:PutObject', - 's3:DeleteObject', - 's3:GetObject', - 's3:GetBucketLocation', - 's3:PutBucketPolicy', - 's3:GetBucketPolicy', + 'glue:TagResource', + 'events:EnableRule', + 'events:PutRule', 'events:TagResource', 'events:PutTargets', 'events:DeleteRule', - 'lambda:AddPermission', - 'secretsmanager:GetSecretValue', - 'athena:GetQueryExecution', 'events:RemoveTargets', - 'lambda:RemovePermission', 'events:UntagResource', - 'events:DisableRule'], - resources: [`arn:${Aws.PARTITION}:lambda:*:${Aws.ACCOUNT_ID}:function:*`, + 'events:DisableRule', + 'sns:Publish', + ], + resources: [ + `arn:${Aws.PARTITION}:lambda:*:${Aws.ACCOUNT_ID}:function:*`, `arn:${Aws.PARTITION}:sqs:${Aws.REGION}:${Aws.ACCOUNT_ID}:${SolutionInfo.SOLUTION_NAME}-DiscoveryJob`, `arn:${Aws.PARTITION}:sqs:${Aws.REGION}:${Aws.ACCOUNT_ID}:${SolutionInfo.SOLUTION_NAME}-Crawler`, `arn:${Aws.PARTITION}:sqs:${Aws.REGION}:${Aws.ACCOUNT_ID}:${SolutionInfo.SOLUTION_NAME}-AutoSyncData`, @@ -199,7 +224,9 @@ export class ApiStack extends Construct { `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_GLUE_DATABASE}/*`, `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_GLUE_DATABASE}`, `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:catalog`, - `arn:${Aws.PARTITION}:events:*:${Aws.ACCOUNT_ID}:rule/${SolutionInfo.SOLUTION_NAME}-*`], + `arn:${Aws.PARTITION}:events:*:${Aws.ACCOUNT_ID}:rule/${SolutionInfo.SOLUTION_NAME}-*`, + `arn:${Aws.PARTITION}:sns:${Aws.REGION}:${Aws.ACCOUNT_ID}:${SolutionInfo.SOLUTION_NAME}-JobCompleted`, + ], }); apiRole.addToPolicy(functionStatement); diff --git a/source/constructs/lib/admin/database/1.1.0-1.1.2/20_update.sql b/source/constructs/lib/admin/database/1.1.0-1.1.2/20_update.sql new file mode 100644 index 00000000..15d5056b --- /dev/null +++ b/source/constructs/lib/admin/database/1.1.0-1.1.2/20_update.sql @@ -0,0 +1,11 @@ +INSERT INTO config (config_key, config_value) VALUES ('ConcurrentRunJobNumber','10'); +INSERT INTO config (config_key, config_value) VALUES ('SubJobNumberS3','80'); +INSERT INTO config (config_key, config_value) VALUES ('SubJobNumberRds','3'); + +alter table discovery_job_database modify account_id varchar(20) null; +alter table discovery_job_database modify region varchar(20) null; +alter table discovery_job_database modify database_type varchar(20) null; +alter table discovery_job_database modify database_name varchar(255) null; + +alter table catalog_database_level_classification add url varchar(2048) default null after database_name; +alter table catalog_database_level_classification add description varchar(2048) default null after database_name; diff --git a/source/constructs/lib/admin/database/1.1.0-1.1.2/99_version.sql b/source/constructs/lib/admin/database/1.1.0-1.1.2/99_version.sql new file mode 100644 index 00000000..88a69b0c --- /dev/null +++ b/source/constructs/lib/admin/database/1.1.0-1.1.2/99_version.sql @@ -0,0 +1 @@ +insert into version 
(value,description,create_by,create_time,modify_by,modify_time) values ('1.1.2','upgrade install','System',now(),'System',now()); \ No newline at end of file diff --git a/source/constructs/lib/admin/database/init_db.py b/source/constructs/lib/admin/database/init_db.py index 59ea0af4..b2096129 100644 --- a/source/constructs/lib/admin/database/init_db.py +++ b/source/constructs/lib/admin/database/init_db.py @@ -63,6 +63,7 @@ def on_update(event): main(event,"1.0.0-1.0.1") main(event,"1.0.1-1.0.2") main(event,"1.0.x-1.1.0") + main(event,"1.1.0-1.1.2") def on_delete(event): diff --git a/source/constructs/lib/admin/database/whole/10_catalog.sql b/source/constructs/lib/admin/database/whole/10_catalog.sql index 0354fd0f..91ba6cb9 100644 --- a/source/constructs/lib/admin/database/whole/10_catalog.sql +++ b/source/constructs/lib/admin/database/whole/10_catalog.sql @@ -74,6 +74,8 @@ create table catalog_database_level_classification region varchar(20) not null, database_type varchar(20) not null, database_name varchar(255) not null, + description varchar(2048) null, + url varchar(2048) null, privacy smallint null, sensitivity varchar(255) not null, object_count bigint null, diff --git a/source/constructs/lib/admin/database/whole/10_data_source.sql b/source/constructs/lib/admin/database/whole/10_data_source.sql index 4b4d9702..1c9ad124 100644 --- a/source/constructs/lib/admin/database/whole/10_data_source.sql +++ b/source/constructs/lib/admin/database/whole/10_data_source.sql @@ -288,4 +288,3 @@ create table source_s3_bucket create index detection_history_id on source_s3_bucket (detection_history_id); - diff --git a/source/constructs/lib/admin/database/whole/10_discovery_job.sql b/source/constructs/lib/admin/database/whole/10_discovery_job.sql index 1e4e59e8..35aef042 100644 --- a/source/constructs/lib/admin/database/whole/10_discovery_job.sql +++ b/source/constructs/lib/admin/database/whole/10_discovery_job.sql @@ -38,10 +38,10 @@ create table discovery_job_database id int auto_increment primary key, job_id int not null, - account_id varchar(20) not null, - region varchar(20) not null, - database_type varchar(20) not null, - database_name varchar(255) not null, + account_id varchar(20) null, + region varchar(20) null, + database_type varchar(20) null, + database_name varchar(255) null, table_name varchar(1000) null, base_time datetime null, version int null, diff --git a/source/constructs/lib/admin/database/whole/90_init.sql b/source/constructs/lib/admin/database/whole/90_init.sql index 8614670b..2837066b 100644 --- a/source/constructs/lib/admin/database/whole/90_init.sql +++ b/source/constructs/lib/admin/database/whole/90_init.sql @@ -1,3 +1,7 @@ +-- Config +INSERT INTO config (config_key, config_value) VALUES ('ConcurrentRunJobNumber','10'); +INSERT INTO config (config_key, config_value) VALUES ('SubJobNumberS3','80'); +INSERT INTO config (config_key, config_value) VALUES ('SubJobNumberRds','3'); -- Template DELETE FROM template WHERE id=1; INSERT INTO template (id, name, snapshot_no, status, version, create_by, create_time, modify_by, modify_time) VALUES (1, 'default-template', 'init', 1, 1, null, null, null, null); diff --git a/source/constructs/lib/admin/database/whole/99_version.sql b/source/constructs/lib/admin/database/whole/99_version.sql index c5607a4f..e8cd05ff 100644 --- a/source/constructs/lib/admin/database/whole/99_version.sql +++ b/source/constructs/lib/admin/database/whole/99_version.sql @@ -1,4 +1,5 @@ insert into version 
(value,description,create_by,create_time,modify_by,modify_time) values ('1.0.0','whole install','System',now(),'System',now()); insert into version (value,description,create_by,create_time,modify_by,modify_time) values ('1.0.1','whole install','System',now(),'System',now()); insert into version (value,description,create_by,create_time,modify_by,modify_time) values ('1.0.2','whole install','System',now(),'System',now()); -insert into version (value,description,create_by,create_time,modify_by,modify_time) values ('1.1.0','whole install','System',now(),'System',now()); \ No newline at end of file +insert into version (value,description,create_by,create_time,modify_by,modify_time) values ('1.1.0','whole install','System',now(),'System',now()); +insert into version (value,description,create_by,create_time,modify_by,modify_time) values ('1.1.2','whole install','System',now(),'System',now()); \ No newline at end of file diff --git a/source/constructs/lib/admin/delete-resources-stack.ts b/source/constructs/lib/admin/delete-resources-stack.ts index d1bf3a46..7c9a4521 100644 --- a/source/constructs/lib/admin/delete-resources-stack.ts +++ b/source/constructs/lib/admin/delete-resources-stack.ts @@ -80,7 +80,7 @@ export class DeleteResourcesStack extends Construct { ], resources: [ `arn:${Aws.PARTITION}:events:*:${Aws.ACCOUNT_ID}:rule/*`, - `arn:${Aws.PARTITION}:lambda:${Aws.REGION}:${Aws.ACCOUNT_ID}:function:${SolutionInfo.SOLUTION_NAME}-RefreshAccount`, + `arn:${Aws.PARTITION}:lambda:${Aws.REGION}:${Aws.ACCOUNT_ID}:function:${SolutionInfo.SOLUTION_NAME}-Controller`, ], }); deleteAdminResourcesRole.addToPolicy(noramlStatement); diff --git a/source/constructs/lib/admin/delete-resources/delete_resources.py b/source/constructs/lib/admin/delete-resources/delete_resources.py index 687e9f23..5f5f2da5 100644 --- a/source/constructs/lib/admin/delete-resources/delete_resources.py +++ b/source/constructs/lib/admin/delete-resources/delete_resources.py @@ -72,9 +72,9 @@ def on_delete(event): def refresh_account(event): - response = lambda_client.invoke(FunctionName=f'{solution_name}-RefreshAccount', - Payload='{"UpdateEvent":"RerefshAccount"}') - print(response) + response = lambda_client.invoke(FunctionName=f'{solution_name}-Controller', + Payload='{"Action":"RefreshAccount"}') + logger.info(response) def __do_delete_rule(rule_name): @@ -109,18 +109,13 @@ def __do_delete_rules(response): def delete_event_rules(): - response = events_client.list_rules( - NamePrefix=f'{solution_name}-Controller-', - Limit=100, - ) - __do_delete_rules(response) + next_token = None while True: - if "NextToken" not in response: - break - next_token = response["NextToken"] - response = events_client.list_rules( - NamePrefix=f'{solution_name}-Controller-', - Limit=100, - NextToken=next_token - ) + if next_token: + response = events_client.list_rules(NamePrefix=f'{solution_name}-Controller-', Limit=100, NextToken=next_token) + else: + response = events_client.list_rules(NamePrefix=f'{solution_name}-Controller-', Limit=100) __do_delete_rules(response) + next_token = response.get('NextToken') + if not next_token: + break diff --git a/source/constructs/lib/admin/glue-stack.ts b/source/constructs/lib/admin/glue-stack.ts index 953eba84..cb9da6d5 100644 --- a/source/constructs/lib/admin/glue-stack.ts +++ b/source/constructs/lib/admin/glue-stack.ts @@ -56,6 +56,22 @@ export class GlueStack extends Construct { destinationKeyPrefix: 'job/script', }); + new S3Deployment.BucketDeployment(this, 'BatchCreateDatasourceTemplate', { + memoryLimit: 512, 
+ ephemeralStorageSize: Size.mebibytes(512), + sources: [S3Deployment.Source.asset('config/batch_create/datasource/template')], + destinationBucket: props.bucket, + destinationKeyPrefix: 'batch_create/datasource/template', + }); + + new S3Deployment.BucketDeployment(this, 'BatchCreateIdentifierTemplate', { + memoryLimit: 512, + ephemeralStorageSize: Size.mebibytes(512), + sources: [S3Deployment.Source.asset('config/batch_create/identifier/template')], + destinationBucket: props.bucket, + destinationKeyPrefix: 'batch_create/identifier/template', + }); + // When upgrading, files with template as the prefix will be deleted // Therefore, the initial template file will no longer be deployed. // new S3Deployment.BucketDeployment(this, 'DeploymentTemplate', { @@ -85,8 +101,8 @@ export class GlueStack extends Construct { name: table_name, description: 'Save SDPS glue detection data', partitionKeys: [{ name: 'year', type: 'smallint' }, - { name: 'month', type: 'smallint' }, - { name: 'day', type: 'smallint' }], + { name: 'month', type: 'smallint' }, + { name: 'day', type: 'smallint' }], parameters: { classification: 'parquet', has_encrypted_data: 'Unencrypted', @@ -107,6 +123,7 @@ export class GlueStack extends Construct { { name: 'identifiers', type: 'array>' }, { name: 'sample_data', type: 'array' }, { name: 'table_size', type: 'int' }, + { name: 'location', type: 'string' }, { name: 's3_location', type: 'string' }, { name: 's3_bucket', type: 'string' }, { name: 'rds_instance_id', type: 'string' }, @@ -205,12 +222,15 @@ export class GlueStack extends Construct { })); const noramlStatement = new PolicyStatement({ effect: Effect.ALLOW, - actions: ['glue:GetTable', + actions: [ + 'glue:GetTable', 'glue:BatchCreatePartition', - 'glue:CreatePartition'], + 'glue:CreatePartition', + 'glue:TagResource', + ], resources: [`arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_GLUE_DATABASE}/*`, - `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_GLUE_DATABASE}`, - `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:catalog`], + `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_GLUE_DATABASE}`, + `arn:${Aws.PARTITION}:glue:*:${Aws.ACCOUNT_ID}:catalog`], }); addPartitionRole.addToPolicy(noramlStatement); diff --git a/source/constructs/lib/admin/rds-stack.ts b/source/constructs/lib/admin/rds-stack.ts index fe99ae7a..101232b9 100644 --- a/source/constructs/lib/admin/rds-stack.ts +++ b/source/constructs/lib/admin/rds-stack.ts @@ -15,6 +15,7 @@ import * as path from 'path'; import { CustomResource, Duration, RemovalPolicy, + Tags, } from 'aws-cdk-lib'; import { InstanceClass, @@ -66,6 +67,7 @@ export class RdsStack extends Construct { vpc: props.vpc, description: 'connet to RDS', }); + Tags.of(this.clientSecurityGroup).add(SolutionInfo.TAG_NAME, `${SolutionInfo.SOLUTION_NAME}-RDS Client`); const rdsSecurityGroup = new SecurityGroup(this, 'RDSSecurityGroup', { // securityGroupName: 'RDS', vpc: props.vpc, @@ -76,6 +78,7 @@ export class RdsStack extends Construct { Port.tcp(this.dbPort), 'Allow RDS client', ); + Tags.of(rdsSecurityGroup).add(SolutionInfo.TAG_NAME, `${SolutionInfo.SOLUTION_NAME}-RDS`); const secretName = `${SolutionInfo.SOLUTION_NAME}`; const dbSecret = new DatabaseSecret(this, 'Secret', { @@ -113,7 +116,7 @@ export class RdsStack extends Construct { removalPolicy: RemovalPolicy.DESTROY, multiAz: true, deletionProtection: false, - caCertificate: CaCertificate.RDS_CA_RDS4096_G1, + caCertificate: 
CaCertificate.RDS_CA_RDS2048_G1, }); new SecretRotation(this, 'SecretRotation', { diff --git a/source/constructs/lib/admin/vpc-stack.ts b/source/constructs/lib/admin/vpc-stack.ts index c6d44920..4a1e7e4e 100644 --- a/source/constructs/lib/admin/vpc-stack.ts +++ b/source/constructs/lib/admin/vpc-stack.ts @@ -68,6 +68,7 @@ export class VpcStack extends Construct { public publicSubnet2 = ''; public privateSubnet1 = ''; public privateSubnet2 = ''; + readonly customDBSecurityGroup: SecurityGroup; constructor(scope: Construct, id: string, props?: VpcProps) { super(scope, id); @@ -78,14 +79,15 @@ export class VpcStack extends Construct { this.createVpc(props); } // Create CustomDB Security Group - const securityGroup = new SecurityGroup(this, 'CustomDBSecurityGroup', { + this.customDBSecurityGroup = new SecurityGroup(this, 'CustomDBSecurityGroup', { vpc: this.vpc, securityGroupName: 'SDPS-CustomDB', description: 'Allow all TCP ingress traffic', }); // Allow ingress on all TCP ports from the same security group - securityGroup.addIngressRule(securityGroup, Port.allTcp()); + this.customDBSecurityGroup.addIngressRule(this.customDBSecurityGroup, Port.allTcp()); + Tags.of(this.customDBSecurityGroup).add(SolutionInfo.TAG_NAME, `${SolutionInfo.SOLUTION_NAME}-ConnectToCustomDatabase`); } private createVpc(props?: VpcProps) { diff --git a/source/constructs/lib/agent-stack.ts b/source/constructs/lib/agent-stack.ts index a90462b5..86d174ba 100755 --- a/source/constructs/lib/agent-stack.ts +++ b/source/constructs/lib/agent-stack.ts @@ -18,7 +18,7 @@ import { AgentRoleStack } from './agent/AgentRole-stack'; import { CrawlerEventbridgeStack } from './agent/CrawlerEventbridge-stack'; import { DeleteAgentResourcesStack } from './agent/DeleteAgentResources-stack'; import { DiscoveryJobStack } from './agent/DiscoveryJob-stack'; -import { RenameResourcesStack } from './agent/RenameResources-stack'; +// import { RenameResourcesStack } from './agent/RenameResources-stack'; import { BucketStack } from './common/bucket-stack'; import { Parameter } from './common/parameter'; import { SolutionInfo } from './common/solution-info'; @@ -54,9 +54,9 @@ export class AgentStack extends Stack { queueName: `${SolutionInfo.SOLUTION_NAME}-AutoSyncData`, }); - new RenameResourcesStack(scope, 'RenameResources', { - adminAccountId: adminAccountId, - }); + // new RenameResourcesStack(scope, 'RenameResources', { + // adminAccountId: adminAccountId, + // }); } constructor(scope: Construct, id: string, props?: StackProps) { diff --git a/source/constructs/lib/agent/AgentRole-stack.ts b/source/constructs/lib/agent/AgentRole-stack.ts index ce32d5bc..fc3b1dbf 100755 --- a/source/constructs/lib/agent/AgentRole-stack.ts +++ b/source/constructs/lib/agent/AgentRole-stack.ts @@ -38,9 +38,6 @@ export class AgentRoleStack extends Construct { new iam.PolicyStatement({ actions: [ 'iam:PassRole', - 'ec2:CreateVpcEndpoint', - 'ec2:DeleteVpcEndpoints', - 'ec2:DeleteNetworkInterface', 'ec2:DescribeNatGateways', 'ec2:DescribeVpcEndpoints', 'ec2:DescribeRouteTables', @@ -48,35 +45,34 @@ export class AgentRoleStack extends Construct { 'secretsmanager:GetSecretValue', 'kms:Decrypt', 'kms:DescribeKey', - 'glue:GetCrawler', - 'glue:GetCrawlers', - 'glue:GetClassifier', - 'glue:GetClassifiers', 'glue:CheckSchemaVersionValidity', 'glue:CreateClassifier', - 'glue:GetSecurityConfiguration', - 'glue:GetSecurityConfigurations', - 'glue:StartCrawler', - 'glue:StopCrawler', - 'glue:GetConnection', - 'glue:GetConnections', - 'glue:UpdateConnection', + 'glue:Get*', + 
'glue:BatchGet*', + 'lakeformation:*', + 's3:List*', + 's3:Get*', ], resources: ['*'], }), new iam.PolicyStatement({ actions: [ 'glue:CreateJob', - 'glue:UpdateJob', 'glue:DeleteJob', - 'glue:GetJob', + 'glue:UpdateJob', 'glue:CreateConnection', 'glue:DeleteConnection', + 'glue:UpdateConnection', 'glue:BatchDeleteTable', 'glue:CreateCrawler', + 'glue:DeleteCrawler', 'glue:UpdateCrawler', + 'glue:StartCrawler', 'glue:StopCrawler', - 'glue:DeleteCrawler', + 'glue:CreateDatabase', + 'glue:DeleteDatabase', + 'glue:UpdateDatabase', + 'glue:TagResource', 'lambda:CreateFunction', 'lambda:DeleteFunction', 'lambda:GetFunction', @@ -85,17 +81,18 @@ export class AgentRoleStack extends Construct { 'states:GetExecutionHistory', 'states:StartExecution', 'states:StopExecution', - 'states:DescribeStateMachine', - 'states:UpdateStateMachine', - 'states:DeleteStateMachine', 'states:CreateStateMachine', + 'states:DeleteStateMachine', + 'states:UpdateStateMachine', + 'states:DescribeStateMachine', 'states:TagResource', 'states:ListTagsForResource', ], resources: [ `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:catalog`, `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*`, - `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_NAME}-*/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:userDefinedFunction/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:job/${SolutionInfo.SOLUTION_NAME}-*`, `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:connection/${SolutionInfo.SOLUTION_NAME}-*`, `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:crawler/${SolutionInfo.SOLUTION_NAME}-*`, @@ -106,22 +103,6 @@ export class AgentRoleStack extends Construct { }), ], })); - // Copy from AmazonS3ReadOnlyAccess, do not modify - roleForAdmin.attachInlinePolicy(new iam.Policy(this, 'AmazonS3ReadOnlyAccessPolicy', { - policyName: 'AmazonS3ReadOnlyAccessPolicy', - statements: [ - new iam.PolicyStatement({ - effect: iam.Effect.ALLOW, - actions: [ - 's3:Get*', - 's3:List*', - 's3-object-lambda:Get*', - 's3-object-lambda:List*', - ], - resources: ['*'], - }), - ], - })); // Copy from AmazonRDSReadOnlyAccess, do not modify roleForAdmin.attachInlinePolicy(new iam.Policy(this, 'AmazonRDSReadOnlyAccessPolicy', { policyName: 'AmazonRDSReadOnlyAccessPolicy', @@ -152,129 +133,78 @@ export class AgentRoleStack extends Construct { }), ], })); - // Copy from AWSLakeFormationDataAdmin, do not modify - roleForAdmin.attachInlinePolicy(new iam.Policy(this, 'AWSLakeFormationDataAdminPolicy', { - policyName: 'AWSLakeFormationDataAdmin', + + const glueDetectionJobRole = new iam.Role(this, 'GlueDetectionJobRole', { + assumedBy: new iam.ServicePrincipal('glue.amazonaws.com'), + roleName: `${SolutionInfo.SOLUTION_NAME}GlueDetectionJobRole-${Aws.REGION}`, //Name must be specified + }); + glueDetectionJobRole.attachInlinePolicy(new iam.Policy(this, 'GlueDetectionJobPolicy', { + policyName: `${SolutionInfo.SOLUTION_NAME}GlueDetectionJobPolicy`, statements: [ new iam.PolicyStatement({ - effect: iam.Effect.ALLOW, actions: [ - 'lakeformation:*', - 'cloudtrail:DescribeTrails', - 'cloudtrail:LookupEvents', - 'glue:GetDatabase', - 'glue:GetDatabases', - 'glue:CreateDatabase', - 'glue:UpdateDatabase', - 'glue:DeleteDatabase', - 'glue:GetConnections', - 
'glue:SearchTables', - 'glue:GetTable', - 'glue:CreateTable', - 'glue:UpdateTable', - 'glue:DeleteTable', - 'glue:GetTableVersions', - 'glue:GetPartitions', - 'glue:GetTables', - 'glue:GetWorkflow', - 'glue:ListWorkflows', - 'glue:BatchGetWorkflows', - 'glue:DeleteWorkflow', - 'glue:GetWorkflowRuns', - 'glue:StartWorkflowRun', - 'glue:GetWorkflow', 's3:ListBucket', - 's3:GetBucketLocation', - 's3:ListAllMyBuckets', - 's3:GetBucketAcl', - 'iam:ListUsers', - 'iam:ListRoles', - 'iam:GetRole', - 'iam:GetRolePolicy', + 's3:GetObject', + 's3:PutObject', // Put object in Admin Bucket.When installing Agent Stack independently, do not know the Admin Bucket name. + 'glue:Get*', + 'glue:BatchGet*', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + 'ec2:DescribeVpcEndpoints', + 'ec2:DescribeRouteTables', + 'ec2:CreateNetworkInterface', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DeleteNetworkInterface', + 'secretsmanager:GetSecretValue', + 'kms:Decrypt', + 'kms:DescribeKey', ], resources: ['*'], }), new iam.PolicyStatement({ - effect: iam.Effect.DENY, actions: [ - 'lakeformation:PutDataLakeSettings', + 'glue:CreateTable', + 'glue:DeleteTable', + 'glue:DeleteTableVersion', + 'glue:UpdateTable', + 'glue:CreateDatabase', + 'glue:UpdatePartition', + 'glue:BatchCreatePartition', + 'glue:BatchUpdatePartition', + 'glue:TagResource', + ], + resources: [ + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:catalog`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:userDefinedFunction/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:job/${SolutionInfo.SOLUTION_NAME}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:connection/${SolutionInfo.SOLUTION_NAME}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:crawler/${SolutionInfo.SOLUTION_NAME}-*`, + ], + }), + new iam.PolicyStatement({ + actions: [ + 'logs:CreateLogGroup', + 'logs:CreateLogStream', + 'logs:PutLogEvents', + ], + resources: [`arn:${Aws.PARTITION}:logs:*:*:/aws-glue/*`], + }), + new iam.PolicyStatement({ + actions: [ + 'ec2:CreateTags', + 'ec2:DeleteTags', + ], + resources: [ + `arn:${Aws.PARTITION}:ec2:*:*:network-interface/*`, + `arn:${Aws.PARTITION}:ec2:*:*:security-group/*`, + `arn:${Aws.PARTITION}:ec2:*:*:instance/*`, ], - resources: ['*'], }), ], })); - const glueDetectionJobRole = new iam.Role(this, 'GlueDetectionJobRole', { - assumedBy: new iam.ServicePrincipal('glue.amazonaws.com'), - roleName: `${SolutionInfo.SOLUTION_NAME}GlueDetectionJobRole-${Aws.REGION}`, //Name must be specified - }); - glueDetectionJobRole.attachInlinePolicy(new iam.Policy(this, 'GlueDetectionJobPolicy', { - policyName: `${SolutionInfo.SOLUTION_NAME}GlueDetectionJobPolicy`, - statements: [new iam.PolicyStatement({ - actions: [ - 'lakeformation:*', - 's3:ListBucket', - 's3:GetObject', - 's3:PutObject', - 'glue:CreateTable', - 'glue:UpdateTable', - 'glue:GetDatabase', - 'glue:GetTables', - 'glue:GetTable', - 'glue:CreateDatabase', - 'glue:DeleteTable', - 'glue:DeleteTableVersion', - 'glue:DeleteDatabase', - 'glue:GetConnection', - 'glue:GetPartition', - 'glue:GetPartitions', - 'glue:UpdatePartition', - 'glue:BatchCreatePartition', - 'glue:BatchUpdatePartition', - 'glue:BatchGetPartition', - 
'glue:BatchGetCustomEntityTypes', - 'ec2:DescribeVpcEndpoints', - 'ec2:DescribeRouteTables', - 'ec2:CreateNetworkInterface', - 'ec2:DeleteNetworkInterface', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeSecurityGroups', - 'ec2:DescribeSubnets', - 'ec2:DescribeVpcAttribute', - 'secretsmanager:GetSecretValue', - 'kms:Decrypt', - 'kms:DescribeKey', - ], - resources: ['*'], - }), - new iam.PolicyStatement({ - actions: [ - 'logs:CreateLogGroup', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - ], - resources: [`arn:${Aws.PARTITION}:logs:*:*:/aws-glue/*`], - }), - new iam.PolicyStatement({ - actions: [ - 'ec2:CreateTags', - 'ec2:DeleteTags', - ], - resources: [ - `arn:${Aws.PARTITION}:ec2:*:*:network-interface/*`, - `arn:${Aws.PARTITION}:ec2:*:*:security-group/*`, - `arn:${Aws.PARTITION}:ec2:*:*:instance/*`, - ], - conditions: { - 'ForAllValues:StringEquals': { - 'aws:TagKeys': [ - 'aws-glue-service-resource', - ], - }, - }, - })], - })); - const lambdaRdsRole = new iam.Role(this, 'LambdaRdsRole', { assumedBy: new iam.ServicePrincipal('lambda.amazonaws.com'), roleName: `${SolutionInfo.SOLUTION_NAME}LambdaRdsRole-${Aws.REGION}`, //Name must be specified diff --git a/source/constructs/lib/agent/DeleteAgentResources-stack.ts b/source/constructs/lib/agent/DeleteAgentResources-stack.ts index 0babce9b..a198bfab 100644 --- a/source/constructs/lib/agent/DeleteAgentResources-stack.ts +++ b/source/constructs/lib/agent/DeleteAgentResources-stack.ts @@ -70,28 +70,29 @@ export class DeleteAgentResourcesStack extends Construct { new PolicyStatement({ effect: Effect.ALLOW, actions: [ - 'glue:GetDatabase', - 'glue:GetDatabases', - 'glue:DeleteDatabase', - 'glue:GetConnections', - 'glue:DeleteConnection', - 'glue:GetTable', - 'glue:GetTables', - 'glue:DeleteTable', - 'glue:GetCrawler', 'glue:ListCrawlers', - 'glue:DeleteCrawler', 'glue:ListJobs', - 'glue:DeleteJob', ], resources: ['*'], }), new PolicyStatement({ effect: Effect.ALLOW, actions: [ + 'glue:DeleteDatabase', + 'glue:DeleteConnection', + 'glue:GetCrawler', + 'glue:DeleteCrawler', + 'glue:DeleteJob', 'sqs:SendMessage', ], resources: [ + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:catalog`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:userDefinedFunction/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:job/${SolutionInfo.SOLUTION_NAME}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:connection/${SolutionInfo.SOLUTION_NAME}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:crawler/${SolutionInfo.SOLUTION_NAME}-*`, `arn:${Aws.PARTITION}:sqs:${Aws.REGION}:${props.adminAccountId}:${props.queueName}`, ], }), diff --git a/source/constructs/lib/agent/DiscoveryJob-stack.ts b/source/constructs/lib/agent/DiscoveryJob-stack.ts index 89d9861d..3af8574d 100644 --- a/source/constructs/lib/agent/DiscoveryJob-stack.ts +++ b/source/constructs/lib/agent/DiscoveryJob-stack.ts @@ -50,100 +50,13 @@ export class DiscoveryJobStack extends Construct { this.createSplitJobFunction(); this.createUnstructuredCrawlerFunction(props); - this.createUnstructuredParserRole(); + this.createUnstructuredParserRole(props); const discoveryJobRole = new Role(this, 'DiscoveryJobRole', { assumedBy: new 
ServicePrincipal('states.amazonaws.com'), roleName: `${SolutionInfo.SOLUTION_NAME}DiscoveryJobRole-${Aws.REGION}`, //Name must be specified }); - // Copy from AWSGlueServiceRole, do not modify - discoveryJobRole.attachInlinePolicy(new Policy(this, 'AWSGlueServicePolicy', { - policyName: 'AWSGlueServicePolicy', - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'glue:*', - 's3:GetBucketLocation', - 's3:ListBucket', - 's3:ListAllMyBuckets', - 's3:GetBucketAcl', - 'ec2:DescribeVpcEndpoints', - 'ec2:DescribeRouteTables', - 'ec2:CreateNetworkInterface', - 'ec2:DeleteNetworkInterface', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeSecurityGroups', - 'ec2:DescribeSubnets', - 'ec2:DescribeVpcAttribute', - 'iam:ListRolePolicies', - 'iam:GetRole', - 'iam:GetRolePolicy', - 'cloudwatch:PutMetricData', - ], - resources: ['*'], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: ['s3:CreateBucket'], - resources: [ - `arn:${Aws.PARTITION}:s3:::aws-glue-*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 's3:GetObject', - 's3:PutObject', - 's3:DeleteObject', - ], - resources: [ - `arn:${Aws.PARTITION}:s3:::aws-glue-*/*`, - `arn:${Aws.PARTITION}:s3:::*/*aws-glue-*/*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 's3:GetObject', - ], - resources: [ - `arn:${Aws.PARTITION}:s3:::crawler-public*`, - `arn:${Aws.PARTITION}:s3:::aws-glue-*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'logs:CreateLogGroup', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - ], - resources: [ - `arn:${Aws.PARTITION}:logs:*:*:/aws-glue/*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'ec2:CreateTags', - 'ec2:DeleteTags', - ], - conditions: { - StringEquals: { - 'aws:TagKeys': 'aws-glue-service-resource', - }, - }, - resources: [ - `arn:${Aws.PARTITION}:ec2:*:*:network-interface/*`, - `arn:${Aws.PARTITION}:ec2:*:*:security-group/*`, - `arn:${Aws.PARTITION}:ec2:*:*:instance/*`, - ], - }), - ], - })); - const discoveryJobPolicy = new Policy(this, 'DiscoveryJobPolicy', { policyName: `${SolutionInfo.SOLUTION_NAME}DiscoveryJobPolicy`, statements: [ @@ -158,6 +71,13 @@ export class DiscoveryJobStack extends Construct { 'events:PutRule', 'events:DescribeRule', 'iam:PassRole', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + 'ec2:DescribeVpcEndpoints', + 'ec2:DescribeRouteTables', + 'ec2:CreateNetworkInterface', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DeleteNetworkInterface', ], resources: ['*'], }), @@ -169,12 +89,22 @@ export class DiscoveryJobStack extends Construct { 'lambda:InvokeFunction', 'ssm:GetParameter', 'sqs:SendMessage', + 'glue:GetCrawler', + 'glue:StartCrawler', + 'glue:StopCrawler', + 'glue:StartJobRun', + 'glue:BatchStopJobRun', + 'glue:GetJobRun', + 'glue:GetJobRuns', + 'glue:TagResource', ], resources: [ `arn:${Aws.PARTITION}:sagemaker:${Aws.REGION}:${Aws.ACCOUNT_ID}:processing-job/${SolutionInfo.SOLUTION_NAME}-*`, `arn:${Aws.PARTITION}:lambda:${Aws.REGION}:${Aws.ACCOUNT_ID}:function:${SolutionInfo.SOLUTION_NAME}-*`, `arn:${Aws.PARTITION}:ssm:${Aws.REGION}:${Aws.ACCOUNT_ID}:parameter/${SolutionInfo.SOLUTION_NAME}-AgentBucketName`, `arn:${Aws.PARTITION}:sqs:${Aws.REGION}:${props.adminAccountId}:${SolutionInfo.SOLUTION_NAME}-DiscoveryJob`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:job/${SolutionInfo.SOLUTION_NAME}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:crawler/${SolutionInfo.SOLUTION_NAME}-*`, ], }), ], @@ 
-242,7 +172,11 @@ export class DiscoveryJobStack extends Construct { actions: [ 'glue:GetTables', ], - resources: ['*'], + resources: [ + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:catalog`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + ], }), ], })); @@ -324,275 +258,48 @@ export class DiscoveryJobStack extends Construct { }); } - private createUnstructuredParserRole() { + private createUnstructuredParserRole(props: DiscoveryJobProps) { const unstructuredParserRole = new Role(this, 'UnstructuredParserRole', { roleName: `${SolutionInfo.SOLUTION_NAME}UnstructuredParserRole-${Aws.REGION}`, //Name must be specified assumedBy: new ServicePrincipal('sagemaker.amazonaws.com'), }); - // Copy from AmazonS3FullAccess, do not modify - unstructuredParserRole.attachInlinePolicy(new Policy(this, 'AmazonS3FullAccessPolicy', { - policyName: 'AmazonS3FullAccessPolicy', - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 's3:*', - 's3-object-lambda:*', - ], - resources: ['*'], - }), - ], - })); - // Copy from AmazonSageMakerFullAccess, do not modify - unstructuredParserRole.attachInlinePolicy(new Policy(this, 'AmazonSageMakerFullAccessPolicy', { - policyName: 'AmazonSageMakerFullAccessPolicy', + unstructuredParserRole.attachInlinePolicy(new Policy(this, 'UnstructuredParsePolicy', { + policyName: 'UnstructuredParsePolicy', statements: [ new PolicyStatement({ effect: Effect.ALLOW, actions: [ - 'sagemaker:*', - ], - notResources: [ - `arn:${Aws.PARTITION}:sagemaker:*:*:domain/*`, - `arn:${Aws.PARTITION}:sagemaker:*:*:user-profile/*`, - `arn:${Aws.PARTITION}:sagemaker:*:*:app/*`, - `arn:${Aws.PARTITION}:sagemaker:*:*:flow-definition/*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'sagemaker:CreatePresignedDomainUrl', - 'sagemaker:DescribeDomain', - 'sagemaker:ListDomains', - 'sagemaker:DescribeUserProfile', - 'sagemaker:ListUserProfiles', - 'sagemaker:*App', - 'sagemaker:ListApps', - ], - resources: ['*'], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'sagemaker:*', - ], - resources: [`arn:${Aws.PARTITION}:sagemaker:*:*:flow-definition/*`], - conditions: { - StringEqualsIfExists: { - 'sagemaker:WorkteamType': [ - 'private-crowd', - 'vendor-crowd', - ], - }, - }, - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'application-autoscaling:DeleteScalingPolicy', - 'application-autoscaling:DeleteScheduledAction', - 'application-autoscaling:DeregisterScalableTarget', - 'application-autoscaling:DescribeScalableTargets', - 'application-autoscaling:DescribeScalingActivities', - 'application-autoscaling:DescribeScalingPolicies', - 'application-autoscaling:DescribeScheduledActions', - 'application-autoscaling:PutScalingPolicy', - 'application-autoscaling:PutScheduledAction', - 'application-autoscaling:RegisterScalableTarget', - 'aws-marketplace:ViewSubscriptions', - 'cloudformation:GetTemplateSummary', - 'cloudwatch:DeleteAlarms', - 'cloudwatch:DescribeAlarms', - 'cloudwatch:GetMetricData', - 'cloudwatch:GetMetricStatistics', - 'cloudwatch:ListMetrics', - 'cloudwatch:PutMetricAlarm', - 'cloudwatch:PutMetricData', - 'codecommit:BatchGetRepositories', - 'codecommit:CreateRepository', - 'codecommit:GetRepository', - 'codecommit:List*', - 'cognito-idp:AdminAddUserToGroup', - 'cognito-idp:AdminCreateUser', - 
'cognito-idp:AdminDeleteUser', - 'cognito-idp:AdminDisableUser', - 'cognito-idp:AdminEnableUser', - 'cognito-idp:AdminRemoveUserFromGroup', - 'cognito-idp:CreateGroup', - 'cognito-idp:CreateUserPool', - 'cognito-idp:CreateUserPoolClient', - 'cognito-idp:CreateUserPoolDomain', - 'cognito-idp:DescribeUserPool', - 'cognito-idp:DescribeUserPoolClient', - 'cognito-idp:List*', - 'cognito-idp:UpdateUserPool', - 'cognito-idp:UpdateUserPoolClient', - 'ec2:CreateNetworkInterface', - 'ec2:CreateNetworkInterfacePermission', - 'ec2:CreateVpcEndpoint', - 'ec2:DeleteNetworkInterface', - 'ec2:DeleteNetworkInterfacePermission', - 'ec2:DescribeDhcpOptions', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeRouteTables', - 'ec2:DescribeSecurityGroups', - 'ec2:DescribeSubnets', - 'ec2:DescribeVpcEndpoints', - 'ec2:DescribeVpcs', + 's3:ListBucket', + 's3:GetObject', 'ecr:BatchCheckLayerAvailability', 'ecr:BatchGetImage', - 'ecr:CreateRepository', - 'ecr:Describe*', 'ecr:GetAuthorizationToken', 'ecr:GetDownloadUrlForLayer', - 'ecr:StartImageScan', - 'elastic-inference:Connect', - 'elasticfilesystem:DescribeFileSystems', - 'elasticfilesystem:DescribeMountTargets', - 'fsx:DescribeFileSystems', - 'glue:CreateJob', - 'glue:DeleteJob', - 'glue:GetJob*', - 'glue:GetTable*', - 'glue:GetWorkflowRun', - 'glue:ResetJobBookmark', - 'glue:StartJobRun', - 'glue:StartWorkflowRun', - 'glue:UpdateJob', - 'groundtruthlabeling:*', - 'iam:ListRoles', - 'kms:DescribeKey', - 'kms:ListAliases', - 'lambda:ListFunctions', - 'logs:CreateLogDelivery', 'logs:CreateLogGroup', 'logs:CreateLogStream', - 'logs:DeleteLogDelivery', - 'logs:Describe*', - 'logs:GetLogDelivery', - 'logs:GetLogEvents', - 'logs:ListLogDeliveries', 'logs:PutLogEvents', - 'logs:PutResourcePolicy', - 'logs:UpdateLogDelivery', - 'robomaker:CreateSimulationApplication', - 'robomaker:DescribeSimulationApplication', - 'robomaker:DeleteSimulationApplication', - 'robomaker:CreateSimulationJob', - 'robomaker:DescribeSimulationJob', - 'robomaker:CancelSimulationJob', - 'secretsmanager:ListSecrets', - 'servicecatalog:Describe*', - 'servicecatalog:List*', - 'servicecatalog:ScanProvisionedProducts', - 'servicecatalog:SearchProducts', - 'servicecatalog:SearchProvisionedProducts', - 'sns:ListTopics', - 'tag:GetResources', - ], - resources: ['*'], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'ecr:SetRepositoryPolicy', - 'ecr:CompleteLayerUpload', - 'ecr:BatchDeleteImage', - 'ecr:UploadLayerPart', - 'ecr:DeleteRepositoryPolicy', - 'ecr:InitiateLayerUpload', - 'ecr:DeleteRepository', - 'ecr:PutImage', - ], - resources: [`arn:${Aws.PARTITION}:ecr:*:*:repository/*sagemaker*`], - }), - // From here down, there is no copy - ], - })); - // Copy from AWSGlueServiceRole, do not modify - unstructuredParserRole.attachInlinePolicy(new Policy(this, 'AWSGlueServicePolicy2', { - policyName: 'AWSGlueServicePolicy', - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'glue:*', - 's3:GetBucketLocation', - 's3:ListBucket', - 's3:ListAllMyBuckets', - 's3:GetBucketAcl', - 'ec2:DescribeVpcEndpoints', - 'ec2:DescribeRouteTables', - 'ec2:CreateNetworkInterface', - 'ec2:DeleteNetworkInterface', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeSecurityGroups', - 'ec2:DescribeSubnets', - 'ec2:DescribeVpcAttribute', - 'iam:ListRolePolicies', - 'iam:GetRole', - 'iam:GetRolePolicy', - 'cloudwatch:PutMetricData', ], resources: ['*'], }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: ['s3:CreateBucket'], - resources: [ - 
`arn:${Aws.PARTITION}:s3:::aws-glue-*`, - ], - }), new PolicyStatement({ effect: Effect.ALLOW, actions: [ - 's3:GetObject', + 'glue:CreateDatabase', + 'glue:GetDatabase', + 'glue:CreateTable', + 'glue:DeleteTable', + 'glue:DeleteTableVersion', + 'glue:UpdateTable', + 'glue:GetTables', + 'glue:GetTable', 's3:PutObject', - 's3:DeleteObject', - ], - resources: [ - `arn:${Aws.PARTITION}:s3:::aws-glue-*/*`, - `arn:${Aws.PARTITION}:s3:::*/*aws-glue-*/*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 's3:GetObject', - ], - resources: [ - `arn:${Aws.PARTITION}:s3:::crawler-public*`, - `arn:${Aws.PARTITION}:s3:::aws-glue-*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'logs:CreateLogGroup', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - ], - resources: [ - `arn:${Aws.PARTITION}:logs:*:*:/aws-glue/*`, - ], - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'ec2:CreateTags', - 'ec2:DeleteTags', ], - conditions: { - StringEquals: { - 'aws:TagKeys': 'aws-glue-service-resource', - }, - }, resources: [ - `arn:${Aws.PARTITION}:ec2:*:*:network-interface/*`, - `arn:${Aws.PARTITION}:ec2:*:*:security-group/*`, - `arn:${Aws.PARTITION}:ec2:*:*:instance/*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:catalog`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:database/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*`, + `arn:${Aws.PARTITION}:glue:${Aws.REGION}:${Aws.ACCOUNT_ID}:table/${SolutionInfo.SOLUTION_NAME.toLowerCase()}-*/*`, + `arn:${Aws.PARTITION}:s3:::${props.agentBucketName}/*`, ], }), ], diff --git a/source/constructs/lib/agent/DiscoveryJob.json b/source/constructs/lib/agent/DiscoveryJob.json index d16c45e7..ec82f1fa 100644 --- a/source/constructs/lib/agent/DiscoveryJob.json +++ b/source/constructs/lib/agent/DiscoveryJob.json @@ -453,7 +453,7 @@ "ProcessingResources": { "ClusterConfig": { "InstanceCount": 1, - "InstanceType": "ml.m5.2xlarge", + "InstanceType": "ml.c5.24xlarge", "VolumeSizeInGB": 1 } }, diff --git a/source/constructs/lib/agent/delete-agent-resources/delete_agent_resources.py b/source/constructs/lib/agent/delete-agent-resources/delete_agent_resources.py index d9100609..168732b3 100644 --- a/source/constructs/lib/agent/delete-agent-resources/delete_agent_resources.py +++ b/source/constructs/lib/agent/delete-agent-resources/delete_agent_resources.py @@ -98,9 +98,12 @@ def on_delete(event): def cleanup_crawlers(): crawlers = [] - next_page = '' + next_token = None while True: - response = glue.list_crawlers(NextToken=next_page, Tags={'Owner':solution_name, 'AdminAccountId': admin_account_id}) + if next_token: + response = glue.list_crawlers(Tags={'Owner':solution_name, 'AdminAccountId': admin_account_id}, NextToken=next_token) + else: + response = glue.list_crawlers(Tags={'Owner':solution_name, 'AdminAccountId': admin_account_id}) for crawler_name in response['CrawlerNames']: if not crawler_name.startswith(solution_name + "-"): continue @@ -111,13 +114,13 @@ def cleanup_crawlers(): if database_name.startswith(f"{solution_name}-s3-"): unstructured_database_name = database_name.replace(f"{solution_name}-s3-",f"{solution_name}-unstructured-",1) remove_database(unstructured_database_name) - if len(response_crawler['Crawler']['Targets']['JdbcTargets']) == 1: + if len(response_crawler['Crawler']['Targets']['JdbcTargets']) > 0: connection_name = response_crawler['Crawler']['Targets']['JdbcTargets'][0]['ConnectionName'] remove_jdbc_connection(connection_name) crawlers.append(crawler_name) - 
next_page = response.get('NextToken') - if next_page is None: + next_token = response.get('NextToken') + if not next_token: break return crawlers @@ -146,14 +149,14 @@ def remove_crawler(crawler_name: str): def clean_jobs(): - next_token = '' + next_token = None while True: - response = glue.list_jobs( - NextToken=next_token, - Tags={'Owner':solution_name, "AdminAccountId": admin_account_id} - ) + if next_token: + response = glue.list_jobs(Tags={'Owner':solution_name, "AdminAccountId": admin_account_id}, NextToken=next_token) + else: + response = glue.list_jobs(Tags={'Owner':solution_name, "AdminAccountId": admin_account_id}) for job in response["JobNames"]: glue.delete_job(JobName=job) next_token = response.get("NextToken") - if next_token is None: + if not next_token: break diff --git a/source/constructs/lib/agent/rename-resources/rename_resources.py b/source/constructs/lib/agent/rename-resources/rename_resources.py index 158ecf0e..69ef1a53 100644 --- a/source/constructs/lib/agent/rename-resources/rename_resources.py +++ b/source/constructs/lib/agent/rename-resources/rename_resources.py @@ -68,7 +68,6 @@ def send_response(event, response_status = "SUCCESS", reason = "OK"): def on_create(event): logger.info("Got create") - list_crawlers() def on_update(event): diff --git a/source/constructs/lib/agent/split-job/split_job.py b/source/constructs/lib/agent/split-job/split_job.py index 8ea9096e..a0bdc978 100644 --- a/source/constructs/lib/agent/split-job/split_job.py +++ b/source/constructs/lib/agent/split-job/split_job.py @@ -11,18 +11,18 @@ def get_table_count(glue_database_name, base_time): - next_token = "" + next_token = None table_count = 0 while True: - response = glue.get_tables( - # CatalogId=catalog_id, - DatabaseName=glue_database_name, - NextToken=next_token) + if next_token: + response = glue.get_tables(DatabaseName=glue_database_name, NextToken=next_token) + else: + response = glue.get_tables(DatabaseName=glue_database_name) for table in response['TableList']: if table.get('Parameters', {}).get('classification', '') != 'UNKNOWN' and table['UpdateTime'] > base_time: table_count += 1 next_token = response.get('NextToken') - if next_token is None: + if not next_token: break return table_count @@ -34,7 +34,7 @@ def divide_and_round_up(a, b): def get_job_number(event): if "JobNumber" in event: return event["JobNumber"] - if event['DatabaseType'] == "s3" or event['DatabaseType'] == "glue": + if event['DatabaseType'] in ["s3","glue"]: return 10 return 3 diff --git a/source/constructs/package.json b/source/constructs/package.json index d1f9c0d0..690dbf93 100755 --- a/source/constructs/package.json +++ b/source/constructs/package.json @@ -35,7 +35,7 @@ "typescript": "^4.8.4" }, "dependencies": { - "aws-cdk-lib": "^2.102.0", + "aws-cdk-lib": "^2.133.0", "cdk-bootstrapless-synthesizer": "^2.2.7", "cdk-nag": "^2.27.171" }, diff --git a/source/containers/document-pii-detection/main.py b/source/containers/document-pii-detection/main.py index 67024838..39d5be93 100644 --- a/source/containers/document-pii-detection/main.py +++ b/source/containers/document-pii-detection/main.py @@ -59,12 +59,12 @@ def split_dictionary(raw_dictionary, chunk_size=100): def get_previous_tables(glue_client, database_name): tables = [] - next_token = "" + next_token = None while True: - response = glue_client.get_tables( - DatabaseName=database_name, - NextToken=next_token - ) + if next_token: + response = glue_client.get_tables(DatabaseName=database_name, NextToken=next_token) + else: + response = 
glue_client.get_tables(DatabaseName=database_name) for table in response.get('TableList', []): if table.get('Parameters', {}).get('classification', '') != 'UNKNOWN': tables.append(table) diff --git a/source/portal/Dockerfile b/source/portal/Dockerfile index 1e1f90a7..f23c5b31 100644 --- a/source/portal/Dockerfile +++ b/source/portal/Dockerfile @@ -9,4 +9,5 @@ COPY --from=public.ecr.aws/awsguru/aws-lambda-adapter:0.6.0 /lambda-adapter /opt COPY --from=builder /tmp/frontend/build/ /usr/share/nginx/public/ COPY nginx-config/ /etc/nginx/ EXPOSE 8080 -CMD ["nginx", "-g", "daemon off;"] \ No newline at end of file +RUN chmod +x /etc/nginx/start_nginx.sh +ENTRYPOINT ["/etc/nginx/start_nginx.sh"] \ No newline at end of file diff --git a/source/portal/nginx-config/nginx.conf b/source/portal/nginx-config/nginx.conf index d518d4f0..5a803dfb 100644 --- a/source/portal/nginx-config/nginx.conf +++ b/source/portal/nginx-config/nginx.conf @@ -26,7 +26,9 @@ http { proxy_hide_header X-Powered-By; add_header X-Frame-Options deny; add_header X-Content-Type-Options nosniff; - add_header Content-Security-Policy "default-src 'self' *.okta.com *.authing.cn *.amazoncognito.com *.amazonaws.com; img-src 'self' blob: data:; style-src 'self' blob: data:; font-src 'self' blob: data:; script-src 'self';"; + # CSP is dynamically generated based on environmental variables during Docker startup + include /tmp/*.conf; + # add_header Content-Security-Policy "default-src 'self' *.okta.com *.authing.cn *.amazoncognito.com *.amazonaws.com; img-src 'self' blob: data:; style-src 'self' blob: data:; font-src 'self' blob: data:; script-src 'self';"; if ($request_method !~ ^(GET|HEAD|POST)$ ) { return 444; } diff --git a/source/portal/nginx-config/start_nginx.sh b/source/portal/nginx-config/start_nginx.sh new file mode 100755 index 00000000..f6ef314b --- /dev/null +++ b/source/portal/nginx-config/start_nginx.sh @@ -0,0 +1,36 @@ +#!/bin/sh + +wrote=false + +function write_domain_name(){ + # In the lambda environment, except for tmp, everything is read-only + wrote=true + csp=" add_header Content-Security-Policy \"default-src 'self' $1; img-src 'self' blob: data: ; style-src 'self' blob: data:; font-src 'self' blob: data:; script-src 'self';\";" + echo $csp > /tmp/CustomDomainName.conf +} +if [ -n "$CustomDomainName" ]; then + write_domain_name $CustomDomainName +else + if [ -n "$OidcIssuer" ]; then + # Due to the need to access external networks to obtain authorization_endpoint, the openid configuration is not parsed. + domain_name=$(echo "$OidcIssuer" | sed -n 's/^\(.*\:\/\/\)\([^\/]*\).*/\2/p') + build_in_domain_names="okta.com authing.cn amazoncognito.com amazonaws.com" + IFS=' ' + exist=false + for build_in_domain_name in $build_in_domain_names; do + if [[ $domain_name == *"$build_in_domain_name"* ]]; then + exist=true + break + fi + done + if [ "$exist" = false ]; then + sub_domain_name=$(echo "$domain_name" | awk -F'.' 
'{print $(NF-1)"."$NF}') + wildcard="*.$sub_domain_name" + write_domain_name $wildcard + fi + fi +fi +if [ "$wrote" = false ]; then + write_domain_name "*.okta.com *.authing.cn *.amazoncognito.com *.amazonaws.com" +fi +nginx -g "daemon off;" diff --git a/source/portal/package-lock.json b/source/portal/package-lock.json index dcadf77f..21393002 100644 --- a/source/portal/package-lock.json +++ b/source/portal/package-lock.json @@ -14,12 +14,13 @@ "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", "@types/jest": "^27.5.2", - "@types/node": "^16.18.3", + "@types/node": "^16.18.83", "@types/react": "^18.0.25", "@types/react-dom": "^18.0.8", "axios": "^1.6.8", "classnames": "^2.3.2", "date-fns": "^2.30.0", + "downloadjs": "^1.4.7", "i18next": "^22.4.15", "i18next-browser-languagedetector": "^7.0.1", "i18next-http-backend": "^2.2.0", @@ -37,6 +38,7 @@ "@babel/core": "^7.16.0", "@pmmmwh/react-refresh-webpack-plugin": "^0.5.3", "@svgr/webpack": "^5.5.0", + "@types/downloadjs": "^1.4.6", "@types/lodash": "^4.14.191", "@types/react-simple-maps": "^3.0.0", "@typescript-eslint/eslint-plugin": "^5.42.0", @@ -4335,6 +4337,12 @@ "@types/d3-selection": "^2" } }, + "node_modules/@types/downloadjs": { + "version": "1.4.6", + "resolved": "https://registry.npmjs.org/@types/downloadjs/-/downloadjs-1.4.6.tgz", + "integrity": "sha512-mp3w70vsaiLRT9ix92fmI9Ob2yJAPZm6tShJtofo2uHbN11G2i6a0ApIEjBl/kv3e9V7Pv7jMjk1bUwYWvMHvA==", + "dev": true + }, "node_modules/@types/eslint": { "version": "8.37.0", "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.37.0.tgz", @@ -4473,9 +4481,9 @@ "dev": true }, "node_modules/@types/node": { - "version": "16.18.29", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.29.tgz", - "integrity": "sha512-cal+XTYF4JBwG82kw3m9ktTOyUj7GXcO9i2o+t49y/OF+3asYfpHqTROF1UbV91e71g/UB5wNeL5hfqPthzp8Q==" + "version": "16.18.83", + "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.83.tgz", + "integrity": "sha512-TmBqzDY/GeCEmLob/31SunOQnqYE3ZiiuEh1U9o3HqE1E2cqKZQA5RQg4krEguCY3StnkXyDmCny75qyFLx/rA==" }, "node_modules/@types/parse-json": { "version": "4.0.0", @@ -7654,6 +7662,11 @@ "integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==", "dev": true }, + "node_modules/downloadjs": { + "version": "1.4.7", + "resolved": "https://registry.npmjs.org/downloadjs/-/downloadjs-1.4.7.tgz", + "integrity": "sha512-LN1gO7+u9xjU5oEScGFKvXhYf7Y/empUIIEAGBs1LzUq/rg5duiDrkuH5A2lQGd5jfMOb9X9usDa2oVXwJ0U/Q==" + }, "node_modules/duplexer": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", @@ -22507,6 +22520,12 @@ "@types/d3-selection": "^2" } }, + "@types/downloadjs": { + "version": "1.4.6", + "resolved": "https://registry.npmjs.org/@types/downloadjs/-/downloadjs-1.4.6.tgz", + "integrity": "sha512-mp3w70vsaiLRT9ix92fmI9Ob2yJAPZm6tShJtofo2uHbN11G2i6a0ApIEjBl/kv3e9V7Pv7jMjk1bUwYWvMHvA==", + "dev": true + }, "@types/eslint": { "version": "8.37.0", "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.37.0.tgz", @@ -22645,9 +22664,9 @@ "dev": true }, "@types/node": { - "version": "16.18.29", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.29.tgz", - "integrity": "sha512-cal+XTYF4JBwG82kw3m9ktTOyUj7GXcO9i2o+t49y/OF+3asYfpHqTROF1UbV91e71g/UB5wNeL5hfqPthzp8Q==" + "version": "16.18.83", + "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.83.tgz", + "integrity": 
"sha512-TmBqzDY/GeCEmLob/31SunOQnqYE3ZiiuEh1U9o3HqE1E2cqKZQA5RQg4krEguCY3StnkXyDmCny75qyFLx/rA==" }, "@types/parse-json": { "version": "4.0.0", @@ -25084,6 +25103,11 @@ "integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==", "dev": true }, + "downloadjs": { + "version": "1.4.7", + "resolved": "https://registry.npmjs.org/downloadjs/-/downloadjs-1.4.7.tgz", + "integrity": "sha512-LN1gO7+u9xjU5oEScGFKvXhYf7Y/empUIIEAGBs1LzUq/rg5duiDrkuH5A2lQGd5jfMOb9X9usDa2oVXwJ0U/Q==" + }, "duplexer": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", diff --git a/source/portal/package.json b/source/portal/package.json index eba593b8..06bc7410 100644 --- a/source/portal/package.json +++ b/source/portal/package.json @@ -9,12 +9,13 @@ "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", "@types/jest": "^27.5.2", - "@types/node": "^16.18.3", + "@types/node": "^16.18.83", "@types/react": "^18.0.25", "@types/react-dom": "^18.0.8", "axios": "^1.6.8", "classnames": "^2.3.2", "date-fns": "^2.30.0", + "downloadjs": "^1.4.7", "i18next": "^22.4.15", "i18next-browser-languagedetector": "^7.0.1", "i18next-http-backend": "^2.2.0", @@ -59,6 +60,7 @@ "@babel/core": "^7.16.0", "@pmmmwh/react-refresh-webpack-plugin": "^0.5.3", "@svgr/webpack": "^5.5.0", + "@types/downloadjs": "^1.4.6", "@types/lodash": "^4.14.191", "@types/react-simple-maps": "^3.0.0", "@typescript-eslint/eslint-plugin": "^5.42.0", diff --git a/source/portal/public/aws-exports.json b/source/portal/public/aws-exports.json index 4fe2e9a7..4bfc924a 100644 --- a/source/portal/public/aws-exports.json +++ b/source/portal/public/aws-exports.json @@ -10,6 +10,6 @@ "aws_user_pools_id": "", "aws_user_pools_web_client_id": "", "version": "v1.0.0", - "backend_url": "https://sdps-dev.demo.solutions.aws.a2z.org.cn:444/", + "backend_url": "https://sdps-dev.demo.solutions.aws.a2z.org.cn:444", "expired": 12 } diff --git a/source/portal/public/locales/en/common.json b/source/portal/public/locales/en/common.json index 4a94c17a..5c66c193 100644 --- a/source/portal/public/locales/en/common.json +++ b/source/portal/public/locales/en/common.json @@ -20,7 +20,8 @@ "manageIdentifier": "Manage data identifiers", "doc": "Documentation", "version": "Version", - "dataDiscovery": "Data discovery" + "dataDiscovery": "Data discovery", + "systemSettings": "System settings" }, "breadcrumb": { "home": "Sensitive Data Protection Solution", @@ -40,6 +41,7 @@ }, "button": { "authorize": "Authorize", + "addAndAuthorize": "Add data source and authorize", "signin": "Sign In", "cancel": "Cancel", "delete": "Delete", @@ -98,12 +100,19 @@ "savexlsxSensitiveOnly": "Download .xlsx file (Sensitive data only)", "savecsvSensitiveOnly": "Download .csv file (Sensitive data only)", "addDataSource": "Add data source", + "addDataSourceBatch": "Batch add data source", "deleteDataSource": "Delete data source", "deleteDataSourceOnly": "Delete data catalog only", "disconnectDeleteCatalog": "Disconnect & Delete catalog", "deleteDB": "Delete Database", "editDataSource": "Edit data source", - "download": "Download" + "download": "Download", + "batchOperation": "Batch Operation", + "batchCreate": "Batch Create", + "batchExportDS": "Batch Export Datasource", + "batchExportIdentify": "Batch Export Identify", + "estimateIP": "Estimate IP usage", + "upload": "Upload" }, "label": { "label": "Custom label", @@ -254,8 +263,8 @@ "dbCount": "Database count", "link": "Catalog Link", "propertyValue": { - 
"jobName" : "Job name", - "templateId" : "Template ID", + "jobName": "Job name", + "templateId": "Template ID", "schedule": "Schedule", "description": "Description", "range": "Range", @@ -393,6 +402,7 @@ "changeEnableError": "Change enable error", "selectCategoryLabel": "Please select category/label", "pending": "Pending", + "PENDING": "PENDING", "SUCCEEDED": "SUCCEEDED", "RUNNING": "RUNNING", "FAILED": "FAILED", @@ -423,5 +433,56 @@ "glue": "Glue data catalogs", "jdbc": "Custom database (JDBC)", "structureData": "Structured data", - "unstructuredData": "Unstructured data" + "unstructuredData": "Unstructured data", + "uploadSuccess": "Upload Success", + "settings": { + "title": "Settings", + "desc": "Overall settings for sensitive data discovery jobs", + "rdsDataSourceDiscovery": "RDS Data Source Data Discovery", + "rdsDetectedConcurrency": "Number of RDS instance detected concurrently in sensitive discovery job", + "rdsDetectedConcurrencyDesc": "How many RDS instances will scanned concurrently", + "rdsSubJobRunNumber": "Number of sub-job runs can be used for 1 RDS scan", + "rdsSubJobRunNumberDesc": "How many Glue job runs can be used for 1 RDS scan", + "subnet": "Subnet ", + "subnetNameDesc": "Total number of left IPs in subnet", + "currentIPLeft": "Current IP left", + "subnetDesc": "IP usage per subnet = (3 + (Number of sub-job runs can be used for 1 RDS scan * 2)) * Number of RDS instance detected concurrently", + "estimateResult": "Based on the above settings, for each job run it will consume {{ipCount}} IPs maximum per subnet.", + "estimateError": "The IP in discovery job can not be more than the IP left of subnets. Please adjust the settings.", + "estimateSuccess": "Config validate successfully.", + "estimateFirst": "Please click validate button to validate the settings." + }, + "batch": { + "name": "Batch Operation", + "nameDescDataSource": "Operate data source in batch", + "nameDescIdentifier": "Operate identifier in batch", + "tabDataSource": "Batch data source creation", + "tabIdentifier": "Batch identifier creation", + "step1Title": "Step 1: Download Template", + "step1Desc": "Follow instruction in template to fill in information", + "step1Download": "Download template", + "step2Title": "Step 2: Follow the instruction to fill in the template", + "step2Desc": "Fill the information in the template", + "step2Tips1": "Making sure no duplicates", + "step3Title": "Step 3: Upload the template with filled information", + "uploadTitle": "Fill in the template and upload", + "fileExtensionError": "Uploaded file must have an xlsx extension.", + "chooseFiles": "Choose files", + "chooseFile": "Choose file", + "dropFilesUpload": "Drop files to upload", + "dropFileUpload": "Drop file to upload", + "removeFile": "Remove file ", + "showFewer": "Show fewer files", + "showMore": "Show more files", + "error": "Error", + "only": ".xlsx files only", + "successTitle": "Successfully create data sources", + "successDesc": "{{successCount}} succeeded, {{warningCount}} warnings. Please download the report and check the result.", + "failedTitle": "Failed create data sources in batch", + "failedDesc": "{{successCount}} succeeded, {{warningCount}} warnings, {{failedCount}} failed. Please download the report and fix the data to upload again to retry.", + "inProgress": "In progress", + "inProgressDesc": "Creating data sources, Please do not close this window. It will takes less than 15 minutes.", + "inProgressIdentifierDesc": "Creating Identifiers, Please do not close this window. 
It will takes less than 15 minutes.", + "dismissAlert": "Please make sure that you have downloaded the batch import report. Once this window is closed, the report will not be available for download again." + } } diff --git a/source/portal/public/locales/en/datasource.json b/source/portal/public/locales/en/datasource.json index 97d8e30b..90a8ab2d 100644 --- a/source/portal/public/locales/en/datasource.json +++ b/source/portal/public/locales/en/datasource.json @@ -7,8 +7,12 @@ "organization": "Organization", "filterBuckets": "Filter buckets", "filterInstances": "Filter instances", - "connectToRDSDataSource": "Connect to RDS data source", + "connectToRDSDataSource": "Authorize to RDS data source", "rdsInstances": "RDS instances", + "credential": "Credential", + "security": "Security group", + "chooseSg": "Choose security groups", + "emptySg": "No security groups", "connectionTips": "The connection may takes around 20-30 seconds.", "connectToDataSourceForAccount": "Connect to data source for account Id: ", "connectToDataSourceForAccountDesc": "You can create data catalogs by connecting data source. ", @@ -76,6 +80,8 @@ "selectSecret": "Please select secret", "username": "Username", "password": "Password", + "inputUsername": "Please input username", + "inputPassword": "Please input password", "networkOption": "Network options", "networkDesc": "If your Amazon Glue job needs to jdbc resource which existed in other vpc or other cloud provider environment, you must provide additional VPC-specific configuration information.", "vpc": "VPC", @@ -86,6 +92,16 @@ "chooseSubnet": "Choose one subnet", "sg": "Security groups", "sgDesc": "Choose one or more security groups to allow access to the data store in your VPC subnet. Security groups are associated to the ENI attached to your subnet. You must choose at least one security group with a self-referencing inbound rule for all TCP ports.", - "chooseSG": "Choose one or more security groups" + "chooseSG": "Choose one or more security groups", + "mysql": "MySQL (Auto discovery)", + "other": "Others", + "otherError": "Other JDBC URL can not start with 'jdbc:mysql://'", + "databaseError": "JDBC Database can not be empty.", + "removeDataSource": "Remove data sources from system", + "deleteDataSourceTips": "Delete data sources permanently? This action cannot be undone.", + "deleteDataSourceFromSystemTipsA": "Are you sure you want to remove the following", + "deleteDataSourceFromSystemTipsB": "data sources from system?", + "removeDataSourceFailed": "Following data sources failed to delete:", + "confirmReason": "Please confirm the reason and try again." } } diff --git a/source/portal/public/locales/en/identifier.json b/source/portal/public/locales/en/identifier.json index 20f4260a..c2012df7 100644 --- a/source/portal/public/locales/en/identifier.json +++ b/source/portal/public/locales/en/identifier.json @@ -32,5 +32,7 @@ "minDis": "Min. occurrence of identification rules", "minDisDesc": "This value applies on any identification rules (only Keyword enabled, only Regex enabled, or both Keywords and Regex are enabled).", "textBased": "Text based", - "imageBased": "Image based" + "imageBased": "Image based", + "IdentifierNameNull": "Identifier name must not be null.", + "RuleNull": "Content validation rules and title keywords validation rules cannot be empty at the same time." 
} diff --git a/source/portal/public/locales/zh/common.json b/source/portal/public/locales/zh/common.json index 6e187455..139544c0 100644 --- a/source/portal/public/locales/zh/common.json +++ b/source/portal/public/locales/zh/common.json @@ -20,7 +20,8 @@ "manageIdentifier": "管理数据识别规则", "doc": "文档", "version": "版本", - "dataDiscovery": "数据发现" + "dataDiscovery": "数据发现", + "systemSettings": "系统设置" }, "breadcrumb": { "home": "敏感数据保护解决方案", @@ -40,6 +41,7 @@ }, "button": { "authorize": "授权", + "addAndAuthorize": "添加数据源并授权", "signin": "登录", "cancel": "取消", "delete": "删除", @@ -98,12 +100,19 @@ "savexlsxSensitiveOnly": "下载 .xlsx 文件(只包含敏感数据)", "savecsvSensitiveOnly": "下载 .csv 文件(只包含敏感数据)", "addDataSource": "添加数据源", + "addDataSourceBatch": "批量添加数据源", "deleteDataSource": "删除数据源", "deleteDataSourceOnly": "仅删除数据目录", "disconnectDeleteCatalog": "断开连接并删除目录", "deleteDB": "删除数据库", "editDataSource": "编辑数据源", - "download": "下载" + "download": "下载", + "batchOperation": "批量操作", + "batchCreate": "批量创建", + "batchExportDS": "批量导出数据源", + "batchExportIdentify": "批量导出标识符", + "estimateIP": "评估 IP 使用情况", + "upload": "上传" }, "label": { "label": "标签", @@ -254,8 +263,8 @@ "dbCount": "数据库个数", "link": "目录链接", "propertyValue": { - "jobName" : "任务名", - "templateId" : "模版ID", + "jobName": "任务名", + "templateId": "模版ID", "schedule": "调度方式", "description": "描述", "range": "范围", @@ -393,6 +402,7 @@ "changeEnableError": "更改启用错误", "selectCategoryLabel": "请选择类别/标签", "pending": "等待", + "PENDING": "等待", "SUCCEEDED": "成功", "RUNNING": "运行中", "FAILED": "失败", @@ -423,5 +433,55 @@ "glue": "Glue 数据目录", "jdbc": "自定义数据库(JDBC)", "structureData": "结构化数据", - "unstructuredData": "非结构化数据" + "unstructuredData": "非结构化数据", + "uploadSuccess": "上传成功", + "settings": { + "title": "设置", + "desc": "敏感数据发现作业的总体设置", + "rdsDataSourceDiscovery": "RDS数据源数据发现", + "rdsDetectedConcurrency": "敏感发现作业中同时检测的 RDS 实例数", + "rdsDetectedConcurrencyDesc": "将同时扫描多少个 RDS 实例", + "rdsSubJobRunNumber": "扫描1个RDS实例使用的子任务个数", + "rdsSubJobRunNumberDesc": "1 次 RDS 扫描可运行多少次 Glue 作业", + "subnet": "子网 ", + "subnetNameDesc": "子网中剩余IP总数", + "subnetDesc": "每个子网的 IP 使用量 = (3 + (可用于 1 次 RDS 扫描的子作业运行数量 * 2)) * 敏感发现作业中同时检测的 RDS 实例数", + "estimateResult": "根据上述设置,对于每个作业运行,每个子网最多将消耗 {{ipCount}} 个 IP。", + "estimateError": "发现作业中的 IP 不能超过子网剩余的 IP。请调整设置。", + "estimateSuccess": "配置验证成功。", + "estimateFirst": "请单击验证按钮以验证设置。" + }, + "batch": { + "name": "批量操作", + "nameDescDataSource": "批量操作数据源", + "nameDescIdentifier": "批量操作标识符", + "tabDataSource": "批量创建数据源", + "tabIdentifier": "批量创建标识符", + "step1Title": "第 1 步:下载模板", + "step1Desc": "按照模板中的说明填写信息", + "step1Download": "下载模板", + "step2Title": "第 2 步:按照提示填写模板", + "step2Desc": "填写模板中的信息", + "step2Tips1": "确保没有重复项", + "step3Title": "第 3 步:上传填写信息的模板", + "uploadTitle": "填写模板并上传", + "fileExtensionError": "上传的文件必须是 xlsx 格式。", + "chooseFiles": "选择文件", + "chooseFile": "选择文件", + "dropFilesUpload": "拖放文件以上传", + "dropFileUpload": "拖放文件以上传", + "removeFile": "移除文件", + "showFewer": "显示较少文件", + "showMore": "显示更多文件", + "error": "错误", + "only": "仅限 .xlsx 文件", + "successTitle": "成功创建数据源", + "successDesc": "{{successCount}} 成功,{{warningCount}} 警告。请下载报告并检查结果。", + "failedTitle": "批量创建数据源失败", + "failedDesc": "{{successCount}} 成功,{{warningCount}} 警告,{{failedCount}} 失败。请下载报告并修复数据以重新上传以重试。", + "inProgress": "进行中", + "inProgressDesc": "正在创建数据源,请不要关闭此窗口。预计耗时不超过15分钟。", + "inProgressIdentifierDesc": "正在创建标识符,请不要关闭此窗口。预计耗时不超过15分钟。", + "dismissAlert": "请确保已经下载了批量导入报告,关闭此窗口后,报告将不可再次下载。" + } } diff --git a/source/portal/public/locales/zh/datasource.json 
b/source/portal/public/locales/zh/datasource.json index d2f04b06..16e3559e 100644 --- a/source/portal/public/locales/zh/datasource.json +++ b/source/portal/public/locales/zh/datasource.json @@ -7,8 +7,12 @@ "organization": "组织", "filterBuckets": "过滤存储桶", "filterInstances": "筛选实例", - "connectToRDSDataSource": "连接到 RDS 数据源", + "connectToRDSDataSource": "授权 RDS 数据源", "rdsInstances": "RDS 实例", + "credential": "认证方式", + "security": "安全组", + "chooseSg": "选择安全组", + "emptySg": "没有相关信息", "connectionTips": "连接可能需要大约 20-30 秒。", "connectToDataSourceForAccount": "连接到账户 ID 的数据源: ", "connectToDataSourceForAccountDesc": "您可以通过连接数据源来创建数据目录。 ", @@ -76,6 +80,8 @@ "selectSecret": "请选择密钥", "username": "用户名", "password": "密码", + "inputUsername": "请输入用户名", + "inputPassword": "请输入密码", "networkOption": "网络选项", "networkDesc": "如果你的 AWS Glue 工作需要连接到其他VPC或者其他云供应商环境中的 jdbc 资源,你需要提供额外的 VPC 特定的配置信息。", "vpc": "VPC", @@ -86,6 +92,16 @@ "chooseSubnet": "选择一个子网", "sg": "安全组", "sgDesc": "选择一个或多个安全组以允许访问在你的 VPC 子网中的数据存储。安全组与你的子网关联的ENI相关联。你必须选择至少一个对所有 TCP 端口有自我引用入站规则的安全组。", - "chooseSG": "选择一个或多个安全组" + "chooseSG": "选择一个或多个安全组", + "mysql": "MySQL(自动发现)", + "other": "其他", + "otherError": "其他 JDBC URL 不能以 'jdbc:mysql://' 开头", + "databaseError": "JDBC 数据库不能为空。", + "removeDataSource": "从系统中删除这些数据源", + "deleteDataSourceTips": "永久删除这些数据源?此操作无法撤消。", + "deleteDataSourceFromSystemTipsA": "你确定从系统移除下面", + "deleteDataSourceFromSystemTipsB": "个数据源?", + "removeDataSourceFailed": "下列数据源删除失败:", + "confirmReason": "请确认原因后重新操作。" } } diff --git a/source/portal/public/locales/zh/identifier.json b/source/portal/public/locales/zh/identifier.json index 9d9a0737..973cc268 100644 --- a/source/portal/public/locales/zh/identifier.json +++ b/source/portal/public/locales/zh/identifier.json @@ -32,5 +32,7 @@ "minDis": "最小识别规则", "minDisDesc": "此值适用于任何识别规则(仅启用关键字、仅启用正则表达式,关键字和正则表达式都启用)", "textBased": "基于文本", - "imageBased": "基于图片" + "imageBased": "基于图片", + "IdentifierNameNull": "标识符名称不能为空", + "RuleNull": "内容的校验规则和标题关键字的校验规则不能同时为空" } diff --git a/source/portal/src/apis/config/api.ts b/source/portal/src/apis/config/api.ts new file mode 100644 index 00000000..70a74985 --- /dev/null +++ b/source/portal/src/apis/config/api.ts @@ -0,0 +1,18 @@ +import { apiRequest } from 'tools/apiRequest'; + +const getSystemConfig = async (params: any) => { + const result = await apiRequest('get', 'config', params); + return result; +}; + +const getSubnetsRunIps = async (params: any) => { + const result = await apiRequest('get', 'config/subnets', params); + return result; +}; + +const updateSystemConfig = async (params: any) => { + const result = await apiRequest('post', 'config', params); + return result; +}; + +export { getSystemConfig, getSubnetsRunIps, updateSystemConfig }; diff --git a/source/portal/src/apis/data-source/api.ts b/source/portal/src/apis/data-source/api.ts index 535231f2..81f45093 100644 --- a/source/portal/src/apis/data-source/api.ts +++ b/source/portal/src/apis/data-source/api.ts @@ -46,13 +46,21 @@ const getDataSourceRdsByPage = async (params: any) => { // 分页获取DataSource Glue列表 const getDataSourceGlueByPage = async (params: any) => { - const result = await apiRequest('post', 'data-source/list-glue-database', params); + const result = await apiRequest( + 'post', + 'data-source/list-glue-database', + params + ); return result; }; // 分页获取DataSource JDBC列表 const getDataSourceJdbcByPage = async (params: any, provider_id: number) => { - const result = await apiRequest('post', `data-source/list-jdbc?provider_id=${provider_id}`, params); + const result = await 
apiRequest( + 'post', + `data-source/list-jdbc?provider_id=${provider_id}`, + params + ); return result; }; @@ -111,38 +119,62 @@ const hideDataSourceRDS = async (params: any) => { return result; }; -const hideDataSourceJDBC = async (params: any) => { - const result = await apiRequest('post', 'data-source/hide-jdbc', params); +const deleteDataSourceJDBC = async (params: any) => { + const result = await apiRequest('post', 'data-source/delete-jdbc', params); return result; }; const deleteDataCatalogS3 = async (params: any) => { - const result = await apiRequest('post', 'data-source/delete-catalog-s3', params); + const result = await apiRequest( + 'post', + 'data-source/delete-catalog-s3', + params + ); return result; }; const deleteDataCatalogRDS = async (params: any) => { - const result = await apiRequest('post', 'data-source/delete-catalog-rds', params); + const result = await apiRequest( + 'post', + 'data-source/delete-catalog-rds', + params + ); return result; }; const deleteDataCatalogJDBC = async (params: any) => { - const result = await apiRequest('post', 'data-source/delete-catalog-jdbc', params); + const result = await apiRequest( + 'post', + 'data-source/delete-catalog-jdbc', + params + ); return result; }; const disconnectAndDeleteS3 = async (params: any) => { - const result = await apiRequest('post', 'data-source/disconnect-delete-catalog-s3', params); + const result = await apiRequest( + 'post', + 'data-source/disconnect-delete-catalog-s3', + params + ); return result; }; const disconnectAndDeleteRDS = async (params: any) => { - const result = await apiRequest('post', 'data-source/disconnect-delete-catalog-rds', params); + const result = await apiRequest( + 'post', + 'data-source/disconnect-delete-catalog-rds', + params + ); return result; }; const disconnectAndDeleteJDBC = async (params: any) => { - const result = await apiRequest('post', 'data-source/disconnect-delete-catalog-jdbc', params); + const result = await apiRequest( + 'post', + 'data-source/disconnect-delete-catalog-jdbc', + params + ); return result; }; @@ -152,55 +184,128 @@ const connectDataSourceJDBC = async (params: any) => { }; const connectDataSourceGlue = async (params: any) => { - const result = await apiRequest('post', 'data-source/sync-glue-database', params); + const result = await apiRequest( + 'post', + 'data-source/sync-glue-database', + params + ); return result; }; -const listGlueConnection = async (params: any) => { - const result = await apiRequest('post', 'data-source/query-glue-connections', params); +const listGlueConnection = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/query-glue-connections', + params + ); return result; }; -const importGlueConnection = async (params: any) => { - const result = await apiRequest('post', 'data-source/import-jdbc-conn', params); +const importGlueConnection = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/import-jdbc-conn', + params + ); return result; }; -const queryNetworkInfo = async (params: any) => { - const result = await apiRequest('post', 'data-source/query-account-network', params); +const queryNetworkInfo = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/query-account-network', + params + ); return result; }; -const queryBuckets = async (params: any) => { +const queryBuckets = async (params: any) => { const result = await apiRequest('post', 'data-source/list-buckets', params); return result; }; -const createConnection = async (params: any) 
=> { +const createConnection = async (params: any) => { const result = await apiRequest('post', 'data-source/add-jdbc-conn', params); return result; }; -const updateConnection = async (params: any) => { - const result = await apiRequest('post', 'data-source/update-jdbc-conn', params); +const updateConnection = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/update-jdbc-conn', + params + ); return result; }; -const queryConnectionDetails = async (params: any) => { - const result = await apiRequest('post', 'data-source/query-connection-detail', params); +const queryConnectionDetails = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/query-connection-detail', + params + ); return result; }; -const deleteGlueDatabase = async (params: any) => { - const result = await apiRequest('post', 'data-source/delete-glue-database', params); +const deleteGlueDatabase = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/delete-glue-database', + params + ); return result; }; -const queryJdbcDatabases = async (params: any) => { +const queryJdbcDatabases = async (params: any) => { const result = await apiRequest('post', 'data-source/jdbc-databases', params); return result; }; +const batchCreateDatasource = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/batch-create', + params.files + ); + return result; +}; + +const queryBatchStatus = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/query-batch-status?batch=' + params.batch, + {} + ); + return result; +}; + +const downloadDataSourceBatchFiles = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/download-batch-file?filename=' + params.filename, + {} + ); + return result; +}; + +const exportDatasource = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/export-datasource?key=' + params.key, + {} + ); + return result; +}; + +const deleteDSReport = async (params: any) => { + const result = await apiRequest( + 'post', + 'data-source/delete-report?key=' + params.key, + {} + ); + return result; +}; + export { getDataSourceS3ByPage, getDataSourceRdsByPage, @@ -226,7 +331,7 @@ export { getDataSourceJdbcByPage, hideDataSourceS3, hideDataSourceRDS, - hideDataSourceJDBC, + deleteDataSourceJDBC, deleteDataCatalogS3, deleteDataCatalogRDS, deleteDataCatalogJDBC, @@ -237,5 +342,10 @@ export { connectDataSourceGlue, deleteGlueDatabase, updateConnection, - queryJdbcDatabases + queryJdbcDatabases, + batchCreateDatasource, + queryBatchStatus, + downloadDataSourceBatchFiles, + exportDatasource, + deleteDSReport }; diff --git a/source/portal/src/apis/data-template/api.ts b/source/portal/src/apis/data-template/api.ts index 17994fd0..32f56a63 100644 --- a/source/portal/src/apis/data-template/api.ts +++ b/source/portal/src/apis/data-template/api.ts @@ -82,6 +82,42 @@ const getTemplateUpdateTime = async () => { return result; }; +const exportIdentify = async (params: any) => { + const result = await apiRequest( + 'post', + 'template/export-identify?key=' + params.key, + {} + ); + return result; +}; + +const deleteIdentifierReport = async (params: any) => { + const result = await apiRequest( + 'post', + 'template/delete-report?key=' + params.key, + {} + ); + return result; +}; + +const downloadIdentifierBatchFiles = async (params: any) => { + const result = await apiRequest( + 'post', + 'template/download-batch-file?filename=' + 
params.filename, + {} + ); + return result; +}; + +const queryIdentifierBatchStatus = async (params: any) => { + const result = await apiRequest( + 'post', + 'template/query-batch-status?batch=' + params.batch, + {} + ); + return result; +}; + export { getTemplateMappingList, deleteTemplateMapping, @@ -93,4 +129,8 @@ export { updateIdentifiers, getIndentifierInTemplate, getTemplateUpdateTime, + exportIdentify, + deleteIdentifierReport, + downloadIdentifierBatchFiles, + queryIdentifierBatchStatus }; diff --git a/source/portal/src/apis/props/api.ts b/source/portal/src/apis/props/api.ts index 23825238..19ee47dc 100644 --- a/source/portal/src/apis/props/api.ts +++ b/source/portal/src/apis/props/api.ts @@ -5,7 +5,7 @@ const requestPropsByType = async (params: { type: string }) => { const result: any = await apiRequest( 'get', `template/list-props-by-type/${params.type}`, - params + undefined ); return result; }; diff --git a/source/portal/src/apis/query/api.ts b/source/portal/src/apis/query/api.ts index d5a67392..4ea7a61e 100644 --- a/source/portal/src/apis/query/api.ts +++ b/source/portal/src/apis/query/api.ts @@ -1,4 +1,9 @@ +import axios from 'axios'; +import download from 'downloadjs'; +import { User } from 'oidc-client-ts'; import { apiRequest } from 'tools/apiRequest'; +import { AMPLIFY_CONFIG_JSON } from 'ts/common'; +import { AmplifyConfigType } from 'ts/types'; /** * 获取筛选项信息 @@ -26,4 +31,35 @@ const sendQuery = async (params: Record | undefined) => { return result; }; -export { getPropertyValues, sendQuery }; +/** + * Download debug logs + * @param params + * @returns + */ +const downloadLogAsZip = async () => { + const configJSONObj: AmplifyConfigType = localStorage.getItem( + AMPLIFY_CONFIG_JSON + ) + ? JSON.parse(localStorage.getItem(AMPLIFY_CONFIG_JSON) || '') + : {}; + const token = + process.env.REACT_APP_ENV === 'local' || + process.env.REACT_APP_ENV === 'development' + ? '' + : User.fromStorageString( + localStorage.getItem( + `oidc.user:${configJSONObj.aws_oidc_issuer}:${configJSONObj.aws_oidc_client_id}` + ) || '' + )?.id_token; + const response = await axios.get('/query/download-logs', { + headers: { + 'Content-Type': 'multipart/form-data', + Authorization: token ? 
`Bearer ${token}` : undefined, + }, + responseType: 'blob' + }); + download(response.data, 'aws_sdps_cloudwatch_logs.zip', 'application/zip'); +}; + +export { downloadLogAsZip, getPropertyValues, sendQuery }; + diff --git a/source/portal/src/index.scss b/source/portal/src/index.scss index 5f7ef13e..f25e65da 100644 --- a/source/portal/src/index.scss +++ b/source/portal/src/index.scss @@ -1,13 +1,15 @@ body { margin: 0; - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", "Ubuntu", "Cantarell", "Fira Sans", - "Droid Sans", "Helvetica Neue", sans-serif; + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', + 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', + sans-serif; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } code { - font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace; + font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', + monospace; } .hand-pointer { @@ -280,7 +282,13 @@ code { color: #3a3a3a; text-shadow: 0 1px 0 rgba(255, 255, 255, 0.75); background-color: #f7f7f7; - background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#ffffff), to(#d2d2d2)); + background-image: -webkit-gradient( + linear, + 0 0, + 0 100%, + from(#ffffff), + to(#d2d2d2) + ); background-image: -webkit-linear-gradient(top, #ffffff, #d2d2d2); background-image: -moz-linear-gradient(top, #ffffff, #d2d2d2); background-image: -ms-linear-gradient(top, #ffffff, #d2d2d2); @@ -313,7 +321,7 @@ code { .popover .arrow:after { z-index: -1; - content: ""; + content: ''; } .popover.top .arrow { @@ -379,20 +387,20 @@ code { } .horizon-bar-chart { - [class^="awsui_grid_"] { + [class^='awsui_grid_'] { display: none !important; } - [class^="awsui_labels-left_"] { + [class^='awsui_labels-left_'] { display: none !important; } - [class*="awsui_axis--emphasized_"] { + [class*='awsui_axis--emphasized_'] { display: none !important; } - [class^="awsui_chart-container__vertical_"] { - [class^="awsui_labels-bottom_"] { + [class^='awsui_chart-container__vertical_'] { + [class^='awsui_labels-bottom_'] { display: none !important; } @@ -471,7 +479,7 @@ code { .custom-badge { background-color: #d1d5db; color: #fff; - font-family: "Open Sans", "Helvetica Neue", Roboto, Arial, sans-serif; + font-family: 'Open Sans', 'Helvetica Neue', Roboto, Arial, sans-serif; font-size: 12px; letter-spacing: 0.005em; line-height: 22px; @@ -506,3 +514,16 @@ code { .add-jdbc-container { padding: 20px; } + +.jdbc-prefix { + // padding: 10px; + border: 2px solid #7d8998; + border-radius: 8px 0 0 8px; + position: relative; + z-index: 10; + font-size: 14px; + line-height: 22px; + margin-right: -8px; + background-color: #eee; + padding: 4px 8px 4px 12px; +} diff --git a/source/portal/src/pages/account-management/index.tsx b/source/portal/src/pages/account-management/index.tsx index b82065ad..877750f0 100644 --- a/source/portal/src/pages/account-management/index.tsx +++ b/source/portal/src/pages/account-management/index.tsx @@ -5,6 +5,7 @@ import AccountList from './componments/AccountList'; import { getSourceCoverage } from 'apis/data-source/api'; import { AppLayout, + Button, ContentLayout, Grid, Header, @@ -14,17 +15,55 @@ import { import CustomBreadCrumb from 'pages/left-menu/CustomBreadCrumb'; import Navigation from 'pages/left-menu/Navigation'; import { getAccountInfomation } from 'apis/dashboard/api'; +import { exportDatasource, deleteDSReport } from 'apis/data-source/api'; import { RouterEnum } from 
'routers/routerEnum'; import { useTranslation } from 'react-i18next'; import HelpInfo from 'common/HelpInfo'; import { buildDocLink } from 'ts/common'; import ProviderTab, { ProviderType } from 'common/ProviderTab'; import { CACHE_CONDITION_KEY } from 'enum/common_types'; +import { useNavigate } from 'react-router-dom'; +import { alertMsg } from 'tools/tools'; +import { format } from 'date-fns'; +import { time } from 'console'; + const AccountManagementHeader: React.FC = () => { const { t } = useTranslation(); + const navigate = useNavigate(); + const [downloading, setDownloading] = useState(false) + + const timeStr = format(new Date(), 'yyyyMMddHHmmss'); + + const batchExport = async () => { + setDownloading(true); + const url: any = await exportDatasource({key: timeStr}); + setDownloading(false); + if (url) { + window.open(url, '_blank'); + setTimeout(() => { + deleteDSReport({key: timeStr}); + }, 2000); + } else { + alertMsg(t('noReportFile'), 'error'); + } + } + return ( -
+
+ + + + } + > {t('account:connectToDataSource')}
); @@ -51,16 +90,18 @@ const AccountManagementContent: React.FC = () => { const [loadingAccounts, setLoadingAccounts] = useState(true); useEffect(() => { - if (currentProvider) { getSourceCoverageData(currentProvider.id); } sessionStorage[CACHE_CONDITION_KEY] = JSON.stringify({ - column: "account_provider_id", - condition: "and", - operation: "in", - values: (currentProvider == null || currentProvider.id === 1)?[1, 4]:[currentProvider.id] - }) + column: 'account_provider_id', + condition: 'and', + operation: 'in', + values: + currentProvider == null || currentProvider.id === 1 + ? [1, 4] + : [currentProvider.id], + }); }, [currentProvider]); const getSourceCoverageData = async (providerId: number | string) => { diff --git a/source/portal/src/pages/batch-operation/index.tsx b/source/portal/src/pages/batch-operation/index.tsx new file mode 100644 index 00000000..f1e78168 --- /dev/null +++ b/source/portal/src/pages/batch-operation/index.tsx @@ -0,0 +1,576 @@ +import { + AppLayout, + Box, + Button, + Container, + ContentLayout, + FileUpload, + Flashbar, + FlashbarProps, + FormField, + Header, + Modal, + ProgressBar, + SpaceBetween, + StatusIndicator, + Tabs, +} from '@cloudscape-design/components'; +import React, { useEffect, useState } from 'react'; +import CustomBreadCrumb from 'pages/left-menu/CustomBreadCrumb'; +import Navigation from 'pages/left-menu/Navigation'; +import { RouterEnum } from 'routers/routerEnum'; +import { useTranslation } from 'react-i18next'; +import HelpInfo from 'common/HelpInfo'; +import { AMPLIFY_CONFIG_JSON, BATCH_SOURCE_ID, BATCH_IDENTIFIER_ID, buildDocLink } from 'ts/common'; +import axios from 'axios'; +import { BASE_URL } from 'tools/apiRequest'; +import { deleteDSReport, downloadDataSourceBatchFiles, queryBatchStatus } from 'apis/data-source/api'; +import { deleteIdentifierReport, downloadIdentifierBatchFiles, queryIdentifierBatchStatus } from 'apis/data-template/api'; +import { alertMsg } from 'tools/tools'; +import { User } from 'oidc-client-ts'; +import { AmplifyConfigType } from 'ts/types'; +import { useParams } from 'react-router-dom'; +import { TFunction } from 'i18next'; + +enum BatchOperationStatus { + NotStarted = 'NotStarted', + Inprogress = 'Inprogress', + Completed = 'Completed', + Error = 'Error', +} +interface BatchOperationContentProps { + type: string, + updateStatus: ( + status: BatchOperationStatus, + success?: number, + warning?: number, + failed?: number + ) => void; +} + +const startDownload = (url: string) => { + const link = document.createElement('a'); + link.href = url; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); +}; + +const genTypeDesc = (type: string|undefined, t: TFunction) => { + if(type === "identifier"){ + return t('common:batch.nameDescIdentifier') + } + return t('common:batch.nameDescDataSource') +} + + +const AddBatchOperationHeader: React.FC = (props:any) => { + const { t } = useTranslation(); + const { type } = props + return ( +
+ {t('common:batch.name')} +
+ ); +}; +let statusInterval: any; + +const BatchOperationContent: React.FC = ( + props: BatchOperationContentProps +) => { + const { t, i18n } = useTranslation(); + const { updateStatus,type } = props; + const [uploadDisabled, setUploadDisabled] = useState(false); + const [files, setFiles] = useState([] as any); + const [errors, setErrors] = useState([] as any); + const [uploadProgress, setUploadProgress] = useState(0); + const [loadingUpload, setLoadingUpload] = useState(false); + const [loadingDownload, setLoadingDownload] = useState(false); + + const queryStatus = async (fileId: string, type:string) => { + + try { + let status:any + if(type==="identifier"){ + status = await queryIdentifierBatchStatus({ + batch: fileId, + }); + } else { + status = await queryBatchStatus({ + batch: fileId, + }); + } + // 0: Inprogress + // data {success: 0, failed: 1, warning: 2}a + if (status?.success > 0 || status?.warning > 0 || status?.failed > 0) { + clearInterval(statusInterval); + if (status.failed > 0 || status.warning > 0) { + updateStatus( + BatchOperationStatus.Error, + status.success, + status.warning, + status.failed + ); + } else { + updateStatus( + BatchOperationStatus.Completed, + status.success, + status.warning, + status.failed + ); + } + } + // else { + // updateStatus(BatchOperationStatus.Inprogress); + // } + } catch (error) { + console.error('error:', error); + clearInterval(statusInterval); + } + }; + + const changeFile = (file: any) => { + setUploadProgress(0); + if (file && file.length > 0) { + if (file[0].name.endsWith('.xlsx') === true) { + setErrors([]); + setUploadDisabled(false); + } else { + setErrors([t('common:batch.fileExtensionError')]); + setUploadDisabled(true); + } + } + setFiles(file); + }; + + const handleUpload = async (type:string) => { + const formData = new FormData(); + let questDomain = 'data-source/batch-create'; + formData.append('files', files[0]); + setLoadingUpload(true); + updateStatus(BatchOperationStatus.Inprogress); + try { + const configJSONObj: AmplifyConfigType = localStorage.getItem( + AMPLIFY_CONFIG_JSON + ) + ? JSON.parse(localStorage.getItem(AMPLIFY_CONFIG_JSON) || '') + : {}; + const token = + process.env.REACT_APP_ENV === 'local' || + process.env.REACT_APP_ENV === 'development' + ? '' + : User.fromStorageString( + localStorage.getItem( + `oidc.user:${configJSONObj.aws_oidc_issuer}:${configJSONObj.aws_oidc_client_id}` + ) || '' + )?.id_token; + if(type==="identifier"){ + questDomain="template/batch-create" + } + const response = await axios.post( + `${BASE_URL}/${questDomain}`, + formData, + { + headers: { + 'Content-Type': 'multipart/form-data', + Authorization: token ? `Bearer ${token}` : undefined, + }, + onUploadProgress: (progressEvent: any) => { + const percentCompleted = Math.round( + (progressEvent.loaded * 100) / progressEvent.total + ); + console.info('percentCompleted:', percentCompleted); + setUploadProgress(percentCompleted); + setFiles([]); + }, + } + ); + console.log(response.data); + setLoadingUpload(false); + if (response.data.status === 'success') { + const fileId = response.data.data; + if(type==="identifier"){ + localStorage.setItem(BATCH_IDENTIFIER_ID, fileId); + } else{ + localStorage.setItem(BATCH_SOURCE_ID, fileId); + } + updateStatus(BatchOperationStatus.Inprogress); + statusInterval = setInterval(() => { + queryStatus(fileId, type); + }, 5000); + } else { + setUploadProgress(0); + alertMsg(response.data.message ?? 
'', 'error'); + } + } catch (error) { + setLoadingUpload(false); + console.error(error); + } + }; + + const downloadTemplate = async (type:string) => { + console.log('download template'); + setLoadingDownload(true); + let url:any + if(type==="identifier"){ + url = await downloadIdentifierBatchFiles({ + filename: `identifier-template-${i18n.language}`, + }); + } else { + url = await downloadDataSourceBatchFiles({ + filename: `template-${i18n.language}`, + }); + } + setLoadingDownload(false); + startDownload(url); + }; + + useEffect(() => { + let fileId: any + if(type==="identifier"){ + fileId = localStorage.getItem(BATCH_IDENTIFIER_ID); + } else { + fileId = localStorage.getItem(BATCH_SOURCE_ID); + } + if (fileId) { + queryStatus(fileId, type); + statusInterval = setInterval(() => { + queryStatus(fileId, type); + }, 5000); + } + return () => { + clearInterval(statusInterval); + }; + }, []); + + return ( + + + {t('common:batch.step1Title')} +
+ } + > +

+ {/* */} + +

+ + + {t('common:batch.step2Title')} + + } + > +
    +
  • {t('common:batch.step2Tips1')}
  • +
+
+ {t('common:batch.step3Title')} + } + > + + + { + changeFile(detail.value); + }} + value={files} + i18nStrings={{ + uploadButtonText: (e) => + e + ? t('common:batch.chooseFiles') + : t('common:batch.chooseFile'), + dropzoneText: (e) => + e + ? t('common:batch.dropFilesUpload') + : t('common:batch.dropFileUpload'), + removeFileAriaLabel: (e) => + `${t('common:batch.removeFile')} ${e + 1}`, + limitShowFewer: t('common:batch.showFewer'), + limitShowMore: t('common:batch.showMore'), + errorIconAriaLabel: t('common:batch.error'), + }} + invalid + fileErrors={errors} + accept=".xlsx" + showFileLastModified + showFileSize + showFileThumbnail + tokenLimit={1} + constraintText={t('common:batch.only')} + /> + + {uploadProgress > 0 && ( + + + {uploadProgress >= 100 && ( + {t('uploadSuccess')} + )} + + )} + + + + + ); +}; + +const genBreadcrumbItem = (type: string|undefined, t: TFunction) => { + if(type === "identifier"){ + return { + text: t('breadcrumb.manageIdentifier'), + href: RouterEnum.TemplateIdentifiers.path, + } + } + return { + text: t('breadcrumb.dataSourceConnection'), + href: RouterEnum.DataSourceConnection.path, + } +} + + +const BatchOperation: React.FC = () => { + const {type}: any = useParams(); + const { t, i18n } = useTranslation(); + const breadcrumbItems = [ + { text: t('breadcrumb.home'), href: RouterEnum.Home.path }, + genBreadcrumbItem(type, t), + ]; + const [flashBar, setFlashBar] = useState( + [] + ); + + const [successCount, setSuccessCount] = useState(0); + const [warningCount, setWarningCount] = useState(0); + const [failedCount, setFailedCount] = useState(0); + const [status, setStatus] = useState(BatchOperationStatus.NotStarted); + const [loadingDownload, setLoadingDownload] = useState(false); + const [showConfirm, setShowConfirm] = useState(false); + + useEffect(()=>{ + console.log("clear flash") + },[flashBar]) + + const downloadReport = async (type: string) => { + console.log('download report'); + setLoadingDownload(true); + const fileName = type==="identifier" ? 
localStorage.getItem(BATCH_IDENTIFIER_ID) : localStorage.getItem(BATCH_SOURCE_ID); + let url:any; + if (fileName) { + if(type==="identifier"){ + url = await downloadIdentifierBatchFiles({ + filename: fileName, + }); + } else { + url = await downloadDataSourceBatchFiles({ + filename: fileName, + }); + } + setLoadingDownload(false); + startDownload(url); + } + setTimeout(() => { + if(type==="identifier"){ + deleteIdentifierReport({key: localStorage.getItem(BATCH_IDENTIFIER_ID)}); + localStorage.removeItem(BATCH_IDENTIFIER_ID); + } else { + deleteDSReport({key: localStorage.getItem(BATCH_SOURCE_ID)}); + localStorage.removeItem(BATCH_SOURCE_ID); + } + setFlashBar([]); + }, 2000); + }; + + const confirmDismissNotification = (type:string) => { + if(type==="identifier"){ + deleteIdentifierReport({key: localStorage.getItem(BATCH_IDENTIFIER_ID)}) + localStorage.removeItem(BATCH_IDENTIFIER_ID); + } else { + deleteDSReport({key: localStorage.getItem(BATCH_SOURCE_ID)}) + localStorage.removeItem(BATCH_SOURCE_ID); + } + setFlashBar([]); + setShowConfirm(false); + }; + + const genTabName = (type:string, t:TFunction) => { + if(type==="identifier"){ + return t('common:batch.tabIdentifier') + } + return t('common:batch.tabDataSource') + } + + const onDismissNotification = () => { + setShowConfirm(true); + }; + + useEffect(() => { + if (status === BatchOperationStatus.Completed) { + setFlashBar([ + { + header: t('common:batch.successTitle'), + type: 'success', + dismissible: true, + content: t('common:batch.successDesc', { + successCount: successCount, + warningCount: warningCount, + }), + id: 'success', + action: ( + + ), + onDismiss: () => { + onDismissNotification(); + }, + }, + ]); + } + if (status === BatchOperationStatus.Error) { + setFlashBar([ + { + header: t('common:batch.failedTitle'), + type: 'error', + dismissible: true, + content: t('common:batch.failedDesc', { + successCount: successCount, + warningCount: warningCount, + failedCount: failedCount, + }), + id: 'error', + action: ( + + ), + onDismiss: () => { + onDismissNotification(); + }, + }, + ]); + } + if (status === BatchOperationStatus.Inprogress) { + setFlashBar([ + { + loading: true, + header: t('common:batch.inProgress'), + type: 'info', + dismissible: false, + content: type==="identifier"?t('common:batch.inProgressIdentifierDesc'):t('common:batch.inProgressDesc'), + id: 'info', + }, + ]); + } + }, [status]); + + return ( + } + tools={ + + } + content={ + }> + { + setStatus(status); + setSuccessCount(successCount ?? 0); + setWarningCount(warningCount ?? 0); + setFailedCount(failedCount ?? 
0); + }} + /> + ), + }, + ]} + /> + setShowConfirm(false)} + visible={showConfirm} + footer={ + + + + + + + } + header={t('confirm')} + > + {t('common:batch.dismissAlert')} + + + } + headerSelector="#header" + breadcrumbs={} + navigation={} + navigationWidth={290} + /> + ); +}; + +export default BatchOperation; diff --git a/source/portal/src/pages/common-badge/index.tsx b/source/portal/src/pages/common-badge/index.tsx index b66bee05..500310a2 100644 --- a/source/portal/src/pages/common-badge/index.tsx +++ b/source/portal/src/pages/common-badge/index.tsx @@ -56,9 +56,15 @@ const CommonBadge: React.FC = (props: CommonBadgeProps) => { 'system-mark-badge': badgeLabel === CLSAAIFIED_TYPE.SystemMark || labelType === CLSAAIFIED_TYPE.SystemMark, + 'authorized-badge': + badgeLabel === CLSAAIFIED_TYPE.Authorized || + labelType === CLSAAIFIED_TYPE.Authorized, 'failed-badge': badgeLabel === CLSAAIFIED_TYPE.Failed || labelType === CLSAAIFIED_TYPE.Failed, + 'pending-badge': + badgeLabel === CLSAAIFIED_TYPE.Pending || + labelType === CLSAAIFIED_TYPE.Pending, }); let iconName: any = 'status-pending'; @@ -84,6 +90,9 @@ const CommonBadge: React.FC = (props: CommonBadgeProps) => { case 'SUCCEEDED': iconName = 'status-positive'; break; + case 'AUTHORIZED': + iconName = 'status-info'; + break; case 'ACTIVE': iconName = 'status-positive'; break; diff --git a/source/portal/src/pages/common-badge/style.scss b/source/portal/src/pages/common-badge/style.scss index 11462b57..e3e9b633 100644 --- a/source/portal/src/pages/common-badge/style.scss +++ b/source/portal/src/pages/common-badge/style.scss @@ -54,6 +54,10 @@ color: #037f0c; } +.authorized-badge { + color: #0972d3; +} + .contain-pii-select { min-width: 80px; position: relative; @@ -74,3 +78,7 @@ .failed-badge { color: #d91515; } + +.pending-badge { + color: #cccccc; +} diff --git a/source/portal/src/pages/common-badge/types/badge_type.ts b/source/portal/src/pages/common-badge/types/badge_type.ts index 1bdc35e2..65aa1a64 100644 --- a/source/portal/src/pages/common-badge/types/badge_type.ts +++ b/source/portal/src/pages/common-badge/types/badge_type.ts @@ -14,11 +14,13 @@ export const CLSAAIFIED_TYPE = { Manual: 'Manual', SystemMark: 'System(?)', Success: 'Success', + Authorized: 'Authorized', Unconnected: 'Unconnected', Failed: 'Failed', Completed: 'Completed', Stopped: 'Stopped', Crawling: 'Crawling', + Pending: 'Pending', }; export const PRIVARY_TYPE_DATA = { diff --git a/source/portal/src/pages/create-identifier/index.tsx b/source/portal/src/pages/create-identifier/index.tsx index 91d82fbd..19dc047b 100644 --- a/source/portal/src/pages/create-identifier/index.tsx +++ b/source/portal/src/pages/create-identifier/index.tsx @@ -3,6 +3,7 @@ import { Button, Container, ExpandableSection, + Flashbar, FormField, Grid, Header, @@ -75,6 +76,21 @@ const CreateIdentifierContent = (props: any) => { }, } = location.state || {}; const { t } = useTranslation(); + // const [errorMsg, setErrorMsg] = useState('') + const [errorItems, setErrorItems] = React.useState([ + // { + // type: "info", + // dismissible: true, + // dismissLabel: "Dismiss message", + // onDismiss: () => setItems([]), + // content: ( + // <> + // This is an info flash message. 
It contains{" "} + // + // ), + // id: "message_1" + // } + ] as any); const [identifierName, setIdentifierName] = useState(oldData.name); const [identifierDescription, setIdentifierDescription] = useState( oldData.description @@ -191,9 +207,45 @@ const CreateIdentifierContent = (props: any) => { return isReg; }; - const submitIdentifier = async () => { + const changeMaxDistance = (value: string)=>{ + if(value === '0') return + setMaxDistance(value) + } + + const changeMinOccurrence = (value: string)=>{ + if(value === '0') return + setMinOccurrence(value) + } + + const submitIdentifier = async (t:any) => { + console.log(keywordList) if (!identifierName) { - alertMsg(t('identifier:inputName'), 'error'); + // alertMsg(t('identifier:inputName'), 'error'); + // setErrorMsg(t('identifier:IdentifierNameNull')) + setErrorItems([...errorItems, { + type: "error", + dismissible: true, + onDismiss: () => setErrorItems([]), + content: ( + <> + {t('identifier:IdentifierNameNull')} + + ) + }]) + return; + } + + if (!patternRex && (keywordList==null || (keywordList.length===1 && !keywordList[0]))) { + setErrorItems([...errorItems, { + type: "error", + dismissible: true, + onDismiss: () => setErrorItems([]), + content: ( + <> + {t('identifier:RuleNull')} + + ) + }]) return; } if (!clkValidate(false)) { @@ -284,6 +336,7 @@ const CreateIdentifierContent = (props: any) => { }; return ( + <> { type="number" value={minOccurrence} onChange={({ detail }) => { - setMinOccurrence(detail.value); + changeMinOccurrence(detail.value); }} placeholder="2" /> @@ -522,28 +575,32 @@ const CreateIdentifierContent = (props: any) => { type="number" value={maxDistance} onChange={({ detail }) => { - setMaxDistance(detail.value); + changeMaxDistance(detail.value); }} placeholder="50" /> - -
+ {/* */} + {/*
{errorMsg &&(<> {errorMsg})}
*/} +
+ {/*
*/} { }} /> + {errorItems && } + ); }; diff --git a/source/portal/src/pages/create-job/components/JobSettings.tsx b/source/portal/src/pages/create-job/components/JobSettings.tsx index e7714378..1baf55f9 100644 --- a/source/portal/src/pages/create-job/components/JobSettings.tsx +++ b/source/portal/src/pages/create-job/components/JobSettings.tsx @@ -21,7 +21,6 @@ import { DAY_OPTIONS, MONTH_OPTIONS, } from '../types/create_data_type'; -import { alertMsg } from 'tools/tools'; import { useTranslation } from 'react-i18next'; import { IJobType, SCAN_FREQUENCY } from 'pages/data-job/types/job_list_type'; import { DEFAULT_TEMPLATE } from 'pages/data-template/types/template_type'; @@ -76,21 +75,18 @@ const JobSettings: React.FC = (props: JobSettingsProps) => { } if (tempType === 'daily') { if (!frequencyStart) { - alertMsg(t('job:selectHourOfDay'), 'error'); return; } setFrequency(`Daily, start time: ${frequencyStart.value}`); } if (tempType === 'weekly') { if (!frequencyStart) { - alertMsg(t('job:selectDayOfWeek'), 'error'); return; } setFrequency(`${t('job:weeklyStartDay')} ${frequencyStart.value}`); } if (tempType === 'monthly') { if (!frequencyStart) { - alertMsg(t('job:selectDayOfMonth'), 'error'); return; } setFrequency(`${t('job:monthlyStartDay')} ${frequencyStart.value}`); diff --git a/source/portal/src/pages/create-job/components/SelectGlueCatalog.tsx b/source/portal/src/pages/create-job/components/SelectGlueCatalog.tsx index 03ecca97..66eddce6 100644 --- a/source/portal/src/pages/create-job/components/SelectGlueCatalog.tsx +++ b/source/portal/src/pages/create-job/components/SelectGlueCatalog.tsx @@ -12,12 +12,13 @@ import { import { COLUMN_OBJECT_STR, GLUE_ACCOUNTS_COLUMNS, + GLUE_CATALOG_COLUMS, RDS_CATALOG_COLUMS, RDS_FOLDER_COLUMS, S3_CATALOG_COLUMS, } from '../types/create_data_type'; import { getDataBaseByType, searchCatalogTables } from 'apis/data-catalog/api'; -import { formatSize } from 'tools/tools'; +import { alertMsg, formatSize } from 'tools/tools'; import CommonBadge from 'pages/common-badge'; import { BADGE_TYPE, @@ -31,6 +32,7 @@ import { IDataSourceType, IJobType } from 'pages/data-job/types/job_list_type'; import { convertAccountListToJobDatabases, convertDataSourceListToJobDatabases, + convertGlueDataSourceListToJobDatabases, convertTableSourceToJobDatabases, } from '../index'; import { getAccountList } from 'apis/account-manager/api'; @@ -39,6 +41,7 @@ import { CATALOG_TABLE_FILTER_COLUMN, RDS_FILTER_COLUMN, } from 'pages/data-catalog/types/data_config'; +import { getDataSourceGlueByPage } from 'apis/data-source/api'; interface SelectS3CatalogProps { jobData: IJobType; @@ -104,27 +107,25 @@ const SelectGlueCatalog: React.FC = ( const requestParam = { page: currentPage, size: preferences.pageSize, - sort_column: '', - asc: true, - conditions: [ - { - column: 'database_type', - values: ['glue'], - condition: 'and', - }, - ] as any, + sort_column: COLUMN_OBJECT_STR.glueDatabaseCreatedTime, + asc: false, + conditions: [] as any, }; - glueQuery.tokens && - glueQuery.tokens.forEach((item: any) => { - requestParam.conditions.push({ - column: item.propertyKey, - values: [`${item.value}`], - condition: glueQuery.operation, - }); - }); - const dataResult = await getDataBaseByType(requestParam); - setGlueCatalogData((dataResult as any)?.items); - setGlueTotal((dataResult as any)?.total); + requestParam.conditions.push({ + column: COLUMN_OBJECT_STR.GlueState, + values: ['UNCONNECTED'], + condition: glueQuery.operation, + operation: '!=', + }); + const result: any = await 
getDataSourceGlueByPage(requestParam); + console.info('result:', result); + setIsLoading(false); + if (!result?.items) { + alertMsg(t('loadDataError'), 'error'); + return; + } + setGlueCatalogData(result.items); + setGlueTotal(result.total); setIsLoading(false); }; @@ -221,7 +222,7 @@ const SelectGlueCatalog: React.FC = ( useEffect(() => { if (jobData.glueSelectedView === GLUE_VIEW.GLUE_INSTANCE_VIEW) { changeSelectDatabases( - convertDataSourceListToJobDatabases( + convertGlueDataSourceListToJobDatabases( selectedGlueItems, jobData.database_type ) @@ -320,7 +321,7 @@ const SelectGlueCatalog: React.FC = ( }} items={glueCatalogData} filter={} - columnDefinitions={RDS_CATALOG_COLUMS.map((item) => { + columnDefinitions={GLUE_CATALOG_COLUMS.map((item) => { return { id: item.id, header: t(item.label), diff --git a/source/portal/src/pages/create-job/components/SelectJDBCCatalog.tsx b/source/portal/src/pages/create-job/components/SelectJDBCCatalog.tsx index c752e019..32a15e74 100644 --- a/source/portal/src/pages/create-job/components/SelectJDBCCatalog.tsx +++ b/source/portal/src/pages/create-job/components/SelectJDBCCatalog.tsx @@ -16,14 +16,14 @@ import { COLUMN_OBJECT_STR, } from '../types/create_data_type'; import { getDataBaseByType, searchCatalogTables } from 'apis/data-catalog/api'; -import { formatNumber, formatSize } from 'tools/tools'; +import { alertMsg, formatNumber, formatSize } from 'tools/tools'; import CommonBadge from 'pages/common-badge'; import { BADGE_TYPE, PRIVARY_TYPE_INT_DATA, } from 'pages/common-badge/types/badge_type'; import ResourcesFilter from 'pages/resources-filter'; -import { JDBC_VIEW, TABLE_NAME } from 'enum/common_types'; +import { JDBC_VIEW, SOURCE_TYPE, TABLE_NAME } from 'enum/common_types'; import { useTranslation } from 'react-i18next'; import { IDataSourceType, IJobType } from 'pages/data-job/types/job_list_type'; import { @@ -34,6 +34,7 @@ import { CATALOG_TABLE_FILTER_COLUMN, RDS_FILTER_COLUMN, } from 'pages/data-catalog/types/data_config'; +import { getDataSourceJdbcByPage } from 'apis/data-source/api'; interface SelectJDBCCatalogProps { jobData: IJobType; @@ -87,31 +88,30 @@ const SelectJDBCCatalog: React.FC = ( const requestParam = { page: currentPage, size: preferences.pageSize, - sort_column: '', - asc: true, - conditions: [ - { - column: 'database_type', - values: - jobData.database_type === 'jdbc_aws' - ? ['jdbc_aws'] - : [jobData.database_type], - operation: 'in', - condition: 'and', - }, - ] as any, + sort_column: COLUMN_OBJECT_STR.LastModifyAt, + asc: false, + conditions: [] as any, }; - jdbcQuery.tokens && - jdbcQuery.tokens.forEach((item: any) => { - requestParam.conditions.push({ - column: item.propertyKey, - values: [`${item.value}`], - condition: jdbcQuery.operation, - }); - }); - const dataResult = await getDataBaseByType(requestParam); - setJdbcCatalogData((dataResult as any)?.items); - setJdbcTotal((dataResult as any)?.total); + requestParam.conditions.push({ + column: COLUMN_OBJECT_STR.GlueState, + values: ['UNCONNECTED'], + condition: jdbcQuery.operation, + operation: '!=', + }); + + const provider_id = jobData.database_type === SOURCE_TYPE.JDBC_PROXY? 
4 : parseInt(jobData.provider_id) + const result: any = await getDataSourceJdbcByPage( + requestParam, + provider_id + ); + console.info('result:', result); + setIsLoading(false); + if (!result?.items) { + alertMsg(t('loadDataError'), 'error'); + return; + } + setJdbcCatalogData(result.items); + setJdbcTotal(result.total); setIsLoading(false); }; diff --git a/source/portal/src/pages/create-job/components/SelectRDSCatalog.tsx b/source/portal/src/pages/create-job/components/SelectRDSCatalog.tsx index e4a43938..11421fe9 100644 --- a/source/portal/src/pages/create-job/components/SelectRDSCatalog.tsx +++ b/source/portal/src/pages/create-job/components/SelectRDSCatalog.tsx @@ -17,7 +17,7 @@ import { S3_CATALOG_COLUMS, } from '../types/create_data_type'; import { getDataBaseByType, searchCatalogTables } from 'apis/data-catalog/api'; -import { formatNumber, formatSize } from 'tools/tools'; +import { alertMsg, formatNumber, formatSize } from 'tools/tools'; import CommonBadge from 'pages/common-badge'; import { BADGE_TYPE, @@ -35,6 +35,7 @@ import { CATALOG_TABLE_FILTER_COLUMN, RDS_FILTER_COLUMN, } from 'pages/data-catalog/types/data_config'; +import { getDataSourceRdsByPage } from 'apis/data-source/api'; interface SelectRDSCatalogProps { jobData: IJobType; @@ -88,27 +89,25 @@ const SelectRDSCatalog: React.FC = ( const requestParam = { page: currentPage, size: preferences.pageSize, - sort_column: '', - asc: true, - conditions: [ - { - column: 'database_type', - values: ['rds'], - condition: 'and', - }, - ] as any, + sort_column: COLUMN_OBJECT_STR.RdsCreatedTime, + asc: false, + conditions: [] as any, }; - rdsQuery.tokens && - rdsQuery.tokens.forEach((item: any) => { - requestParam.conditions.push({ - column: item.propertyKey, - values: [`${item.value}`], - condition: rdsQuery.operation, - }); - }); - const dataResult = await getDataBaseByType(requestParam); - setRdsCatalogData((dataResult as any)?.items); - setRdsTotal((dataResult as any)?.total); + requestParam.conditions.push({ + column: COLUMN_OBJECT_STR.GlueState, + values: ['UNCONNECTED'], + condition: rdsQuery.operation, + operation: '!=', + }); + const result: any = await getDataSourceRdsByPage(requestParam); + console.info('result:', result); + setIsLoading(false); + if (!result?.items) { + alertMsg(t('loadDataError'), 'error'); + return; + } + setRdsCatalogData(result.items); + setRdsTotal(result.total); setIsLoading(false); }; diff --git a/source/portal/src/pages/create-job/components/SelectS3Catalog.tsx b/source/portal/src/pages/create-job/components/SelectS3Catalog.tsx index 5577914d..fb8a58fa 100644 --- a/source/portal/src/pages/create-job/components/SelectS3Catalog.tsx +++ b/source/portal/src/pages/create-job/components/SelectS3Catalog.tsx @@ -14,7 +14,7 @@ import { S3_CATALOG_FILTER_COLUMNS, } from '../types/create_data_type'; import { getDataSourceS3ByPage } from 'apis/data-source/api'; -import { formatSize } from 'tools/tools'; +import { alertMsg, formatSize } from 'tools/tools'; import CommonBadge from 'pages/common-badge'; import { BADGE_TYPE, @@ -64,7 +64,7 @@ const SelectS3Catalog: React.FC = ( query: s3Query, setQuery: setS3Query, columnList: S3_CATALOG_FILTER_COLUMNS.filter((i) => i.filter), - tableName: TABLE_NAME.CATALOG_DATABASE_LEVEL_CLASSIFICATION, + tableName: TABLE_NAME.SOURCE_S3_BUCKET, filteringPlaceholder: t('job:filterBuckets'), }; @@ -79,33 +79,25 @@ const SelectS3Catalog: React.FC = ( const requestParam = { page: currentPage, size: preferences.pageSize, - sort_column: '', - asc: true, - conditions: [ - { - column: 
'glue_state', - values: ['ACTIVE'], - condition: 'and', - operation: ':', - }, - ] as any, + sort_column: COLUMN_OBJECT_STR.LastModifyAt, + asc: false, + conditions: [] as any, }; - if (s3Query.tokens) { - s3Query.tokens.forEach((item: any) => { - requestParam.conditions.push({ - column: - item.propertyKey === COLUMN_OBJECT_STR.DatabaseName - ? COLUMN_OBJECT_STR.BucketName - : item.propertyKey, - values: [`${item.value}`], - condition: s3Query.operation, - }); - }); - } - const dataResult = await getDataSourceS3ByPage(requestParam); - setS3CatalogData((dataResult as any)?.items); - setS3Total((dataResult as any)?.total); + requestParam.conditions.push({ + column: COLUMN_OBJECT_STR.GlueState, + values: ['UNCONNECTED'], + condition: s3Query.operation, + operation: '!=', + }); + const result: any = await getDataSourceS3ByPage(requestParam); + console.info('result:', result); setIsLoading(false); + if (!result?.items) { + alertMsg(t('loadDataError'), 'error'); + return; + } + setS3CatalogData(result.items); + setS3Total(result.total); }; const buildPrivacyColumn = (item: any, e: any) => { diff --git a/source/portal/src/pages/create-job/index.tsx b/source/portal/src/pages/create-job/index.tsx index fb2affc1..0e7fd4e7 100644 --- a/source/portal/src/pages/create-job/index.tsx +++ b/source/portal/src/pages/create-job/index.tsx @@ -37,6 +37,7 @@ import SelectRDSCatalog from './components/SelectRDSCatalog'; import SelectGlueCatalog from './components/SelectGlueCatalog'; import SelectJDBCCatalog from './components/SelectJDBCCatalog'; import { IAccountData } from 'pages/account-management/types/account_type'; +import moment from 'moment'; export const convertAccountListToJobDatabases = ( accountList: IAccountData[], @@ -62,7 +63,22 @@ export const convertDataSourceListToJobDatabases = ( account_id: element.account_id, region: element.region, database_type: source_type, - database_name: element.database_name, + database_name: element.instance_id, + table_name: '', + }; + }); +}; + +export const convertGlueDataSourceListToJobDatabases = ( + dataSources: IDataSourceType[], + source_type: string +) => { + return dataSources.map((element) => { + return { + account_id: element.account_id, + region: element.region, + database_type: source_type, + database_name: element.glue_database_name, table_name: '', }; }); @@ -239,6 +255,40 @@ const CreateJobContent = () => { return true; }; + const cronGeneratorForGlueDaily = (time: string) => { + const timeMoment = moment(time, 'HH:mm'); + timeMoment.subtract(8, 'hours'); + const hours = timeMoment.format('H'); + const minutes = timeMoment.format('m'); + return `${minutes} ${hours} * * ? 
*`; + }; + + const clkFrequencyApply = () => { + const tempType = jobData.frequencyType; + if (tempType === 'on_demand_run') { + return true; + } + if (tempType === 'daily') { + if (!jobData.frequencyTimeStart) { + alertMsg(t('job:selectHourOfDay'), 'error'); + return false; + } + } + if (tempType === 'weekly') { + if (!jobData.frequencyTimeStart) { + alertMsg(t('job:selectDayOfWeek'), 'error'); + return false; + } + } + if (tempType === 'monthly') { + if (!jobData.frequencyTimeStart) { + alertMsg(t('job:selectDayOfMonth'), 'error'); + return false; + } + } + return true; + }; + const submitCreateJob = async () => { setIsLoading(true); let tempFrequency = @@ -259,9 +309,11 @@ const CreateJobContent = () => { // Format the UTC hour as a string utcHourString = utcHourNormalized.toString().padStart(2, '0'); } - + console.info('jobData.frequencyType:', jobData.frequencyType); if (jobData.frequencyType === 'daily') { - tempFrequency = `0 ${utcHourString} * * ? *`; + tempFrequency = cronGeneratorForGlueDaily( + jobData.frequencyTimeStart?.value ?? '' + ); } if (jobData.frequencyType === 'weekly') { const tempTime = @@ -314,7 +366,7 @@ const CreateJobContent = () => { try { const result: any = await createJob(requestParamJob); if (result && result.id && jobData.frequencyType === 'on_demand_run') { - await startJob(result); + startJob(result); } setIsLoading(true); alertMsg(t('submitSuccess'), 'success'); @@ -352,6 +404,12 @@ const CreateJobContent = () => { onSubmit={submitCreateJob} onCancel={cancelCreateJob} onNavigate={({ detail }) => { + console.info(detail); + if (detail.requestedStepIndex === 3) { + if (!clkFrequencyApply()) { + return; + } + } const checkResult = checkMustData(detail.requestedStepIndex); checkResult && setActiveStepIndex(detail.requestedStepIndex); }} @@ -371,11 +429,11 @@ const CreateJobContent = () => { }); }} changeDataSource={(sId) => { - sessionStorage[CACHE_CONDITION_KEY]=JSON.stringify({ + sessionStorage[CACHE_CONDITION_KEY] = JSON.stringify({ column: 'database_type', - condition: "and", - operation: "in", - values: [sId] + condition: 'and', + operation: 'in', + values: [sId], }); setJobData((prev) => { return { diff --git a/source/portal/src/pages/create-job/types/create_data_type.ts b/source/portal/src/pages/create-job/types/create_data_type.ts index 26f47f43..683e1a49 100644 --- a/source/portal/src/pages/create-job/types/create_data_type.ts +++ b/source/portal/src/pages/create-job/types/create_data_type.ts @@ -13,11 +13,14 @@ export const COLUMN_OBJECT_STR = { Region: 'region', LastModifyBy: 'modify_by', LastModifyAt: 'modify_time', + GlueState: 'glue_state', + RdsCreatedTime: 'created_time', + glueDatabaseCreatedTime: 'create_time', }; export const S3_CATALOG_FILTER_COLUMNS = [ { - id: COLUMN_OBJECT_STR.DatabaseName, + id: 'bucket_name', label: 'table.label.bucketName', filter: true, }, @@ -70,20 +73,48 @@ export const S3_CATALOG_COLUMS = [ export const RDS_CATALOG_COLUMS = [ { - id: 'database_name', + id: 'instance_id', label: 'table.label.instanceName', filter: true, }, + // { + // id: 'object_count', + // label: 'table.label.tables', + // filter: true, + // }, + // { + // id: 'privacy', + // label: 'table.label.privacy', + // filter: true, + // }, { - id: 'object_count', - label: 'table.label.tables', + id: 'account_id', + label: 'table.label.awsAccount', filter: true, }, { - id: 'privacy', - label: 'table.label.privacy', + id: 'region', + label: 'table.label.awsRegion', + filter: true, + }, +]; + +export const GLUE_CATALOG_COLUMS = [ + { + id: 
'glue_database_name', + label: 'table.label.instanceName', filter: true, }, + // { + // id: 'object_count', + // label: 'table.label.tables', + // filter: true, + // }, + // { + // id: 'privacy', + // label: 'table.label.privacy', + // filter: true, + // }, { id: 'account_id', label: 'table.label.awsAccount', @@ -97,19 +128,24 @@ export const RDS_CATALOG_COLUMS = [ ]; export const JDBC_INSTANCE_COLUMS = [ + // { + // id: COLUMN_OBJECT_STR.ConnectionName, + // label: 'table.label.connectionName', + // filter: true, + // }, + // { + // id: 'object_count', + // label: 'table.label.tables', + // filter: true, + // }, + // { + // id: 'privacy', + // label: 'table.label.privacy', + // filter: true, + // }, { - id: COLUMN_OBJECT_STR.ConnectionName, - label: 'table.label.connectionName', - filter: true, - }, - { - id: 'object_count', - label: 'table.label.tables', - filter: true, - }, - { - id: 'privacy', - label: 'table.label.privacy', + id: 'instance_id', + label: 'table.label.instanceName', filter: true, }, { diff --git a/source/portal/src/pages/data-job/types/job_list_type.ts b/source/portal/src/pages/data-job/types/job_list_type.ts index 7b4f5012..1bd14ef7 100644 --- a/source/portal/src/pages/data-job/types/job_list_type.ts +++ b/source/portal/src/pages/data-job/types/job_list_type.ts @@ -125,6 +125,8 @@ export interface IDataSourceType { sensitivity: string; labels: Array; table_name: string; + instance_id: string; + glue_database_name: string; } export interface IDataSourceS3BucketType { bucket_name: string; diff --git a/source/portal/src/pages/data-source-connection/componments/DataSourceInfo.tsx b/source/portal/src/pages/data-source-connection/componments/DataSourceInfo.tsx index 1af8f10a..d6e8b827 100644 --- a/source/portal/src/pages/data-source-connection/componments/DataSourceInfo.tsx +++ b/source/portal/src/pages/data-source-connection/componments/DataSourceInfo.tsx @@ -7,9 +7,6 @@ import { getProviderByProviderId } from 'enum/common_types'; const DataSourceInfo: React.FC = ({ accountData }: any) => { const { t } = useTranslation(); - useEffect(() => { - console.log('accountData is', accountData); - }); const [providerType, setProviderType] = useState('AWS'); // genProvider(accountData.account_provider_id) diff --git a/source/portal/src/pages/data-source-connection/componments/DataSourceList.tsx b/source/portal/src/pages/data-source-connection/componments/DataSourceList.tsx index b6f0ccd9..976b58b5 100644 --- a/source/portal/src/pages/data-source-connection/componments/DataSourceList.tsx +++ b/source/portal/src/pages/data-source-connection/componments/DataSourceList.tsx @@ -16,6 +16,10 @@ import { ButtonDropdown, ButtonDropdownProps, StatusIndicator, + Multiselect, + Icon, + Popover, + // Flashbar, } from '@cloudscape-design/components'; import { DATA_TYPE_ENUM, TABLE_NAME } from 'enum/common_types'; import { @@ -47,7 +51,7 @@ import { getSecrets, hideDataSourceS3, hideDataSourceRDS, - hideDataSourceJDBC, + deleteDataSourceJDBC, deleteDataCatalogS3, deleteDataCatalogRDS, deleteDataCatalogJDBC, @@ -70,9 +74,13 @@ import { useTranslation } from 'react-i18next'; import JDBCConnection from './JDBCConnection'; import JDBCConnectionEdit from './JDBCConnectionEdit'; import DataSourceCatalog from './DataSourceCatalog'; +import DataSourceDelete from 'pages/data-source-delete'; const DataSourceList: React.FC = memo((props: any) => { const { tagType, accountData } = props; + const [delFailedItems, setDelFailedItems] = React.useState([] as any); + const [showDelResModal, setShowDelResModal] = 
useState(false) + const [isShowDelete, setIsShowDelete] = useState(false); const { t } = useTranslation(); const columnList = tagType === DATA_TYPE_ENUM.s3 @@ -108,10 +116,12 @@ const DataSourceList: React.FC = memo((props: any) => { const filterTableName = tagType === DATA_TYPE_ENUM.s3 ? TABLE_NAME.SOURCE_S3_BUCKET - : (tagType === DATA_TYPE_ENUM.rds ? (TABLE_NAME.SOURCE_RDS_INSTANCE) - : (tagType === DATA_TYPE_ENUM.glue ? TABLE_NAME.SOURCE_GLUE_DATABASE - : TABLE_NAME.SOURCE_JDBC_CONNECTION)) - // : TABLE_NAME.SOURCE_RDS_INSTANCE; + : tagType === DATA_TYPE_ENUM.rds + ? TABLE_NAME.SOURCE_RDS_INSTANCE + : tagType === DATA_TYPE_ENUM.glue + ? TABLE_NAME.SOURCE_GLUE_DATABASE + : TABLE_NAME.SOURCE_JDBC_CONNECTION; + // : TABLE_NAME.SOURCE_RDS_INSTANCE; const resFilterProps = { totalCount, columnList: columnList.filter((i) => i.filter), @@ -126,6 +136,8 @@ const DataSourceList: React.FC = memo((props: any) => { const [showAddConnection, setShowAddConnection] = useState(false); const [showEditConnection, setShowEditConnection] = useState(false); + const [sgs, setSgs] = useState([] as any); + const [selectedSgs, setSelectedSgs] = useState([] as any); useEffect(() => { if (tagType === DATA_TYPE_ENUM.jdbc && !showAddConnection) { @@ -171,11 +183,16 @@ const DataSourceList: React.FC = memo((props: any) => { id: 'connectAll', disabled: tagType === DATA_TYPE_ENUM.rds, }, - { - text: t('button.addDataSource'), - id: 'addDataSource', - disabled: tagType !== DATA_TYPE_ENUM.jdbc, - }, + // { + // text: t('button.addDataSource'), + // id: 'addDataSource', + // disabled: tagType !== DATA_TYPE_ENUM.jdbc, + // }, + // { + // text: t('button.addDataSourceBatch'), + // id: 'addDataSourceBatch', + // disabled: tagType !== DATA_TYPE_ENUM.jdbc, + // }, { text: t('button.deleteDataSource'), id: 'deleteDataSource', @@ -203,19 +220,19 @@ const DataSourceList: React.FC = memo((props: any) => { } if (tagType === DATA_TYPE_ENUM.jdbc) { res = [ - { - text: t('button.addDataSource'), - id: 'addImportJdbc', - disabled: tagType !== DATA_TYPE_ENUM.jdbc, - }, + // { + // text: t('button.addDataSource'), + // id: 'addImportJdbc', + // disabled: tagType !== DATA_TYPE_ENUM.jdbc, + // }, { text: t('button.editDataSource'), id: 'editJdbc', - disabled: selectedItems.length === 0, + disabled: selectedItems.length !== 1, }, { text: t('button.deleteDataSource'), - id: 'disconnect_dc', + id: 'deleteDataSource', disabled: tagType === DATA_TYPE_ENUM.rds || selectedItems.length === 0, }, @@ -236,6 +253,11 @@ const DataSourceList: React.FC = memo((props: any) => { return res; }; + const closeDelResModal =()=>{ + setDelFailedItems([]) + setShowDelResModal(false) + } + const getPageData = async () => { setIsLoading(true); setSelectedItems([]); @@ -401,9 +423,18 @@ const DataSourceList: React.FC = memo((props: any) => { try { await connectDataSourceJDBC(requestParam); showHideSpinner(false); - alertMsg(t('startConnect'), 'success'); setSelectedItems([]); getPageData(); + alertMsg(t('startConnect'), 'success'); + // setItems([{ + // header: "Failed to update 4 instances", + // type: "error", + // content: "This is a dismissible error message.", + // dismissible: true, + // dismissLabel: "Dismiss message", + // onDismiss: () => setItems([]), + // id: "message_1" + // }]) } catch (error) { setSelectedItems([]); showHideSpinner(false); @@ -592,8 +623,9 @@ const DataSourceList: React.FC = memo((props: any) => { } else if (tagType === DATA_TYPE_ENUM.rds) { requestParam.instance = selectedItems[0].instance_id; } else if (tagType === 
DATA_TYPE_ENUM.jdbc) { - requestParam.instance = selectedItems[0].instance_id; - requestParam.account_provider = selectedItems[0].account_provider_id; + // requestParam.instances = selectedItems[0].instance_id; + // requestParam.instances = selectedItems.map((item: any)=>item.instance_id); + // requestParam.account_provider = selectedItems[0].account_provider_id; } showHideSpinner(true); try { @@ -602,12 +634,15 @@ const DataSourceList: React.FC = memo((props: any) => { } else if (tagType === DATA_TYPE_ENUM.rds) { await hideDataSourceRDS(requestParam); } else if (tagType === DATA_TYPE_ENUM.jdbc) { - await hideDataSourceJDBC(requestParam); + setIsShowDelete(true) + showHideSpinner(false); + return } showHideSpinner(false); - alertMsg(t('disconnectSuccess'), 'success'); + alertMsg(t('deleteSuccess'), 'success'); setSelectedItems([]); getPageData(); + return; } catch (error) { setSelectedItems([]); @@ -615,6 +650,54 @@ const DataSourceList: React.FC = memo((props: any) => { } }; + const confirmDelete = async () => { + const requestParam: any = { + account_id: selectedItems[0].account_id, + region: selectedItems[0].region, + instances: selectedItems.map((item: any)=>item.instance_id), + account_provider: selectedItems[0].account_provider_id + }; + + + showHideSpinner(true); + // try { + // await deleteTemplateMapping(requestParam.ids); + // alertMsg(t('deleteSuccess'), 'success'); + // showHideSpinner(false); + // getPageData(); + // setIsShowDelete(false); + // } catch (error) { + // showHideSpinner(false); + // } + try { + const response:[string, string] = await deleteDataSourceJDBC(requestParam) as [string, string]; + const res = response.filter(item=>item[1] !== '') + setDelFailedItems(res) + console.log("delete jdbc res is >>>>>>"+response) + showHideSpinner(false); + setSelectedItems([]); + getPageData(); + setIsShowDelete(false); + if(res.length === 0){ + alertMsg(t('deleteSuccess'), 'success'); + } else { + setShowDelResModal(true) + } + return; + } catch (error) { + showHideSpinner(false); + } + }; + + const deleteModalProps = { + isShowDelete, + setIsShowDelete, + confirmDelete, + selectedInstances: selectedItems.map((item: any)=>item.instance_id), + title: t('datasource:jdbc.removeDataSource'), + confirmText: t('button.remove'), + }; + const loadAccountSecrets = async () => { const requestParam = { provider: accountData.account_provider_id, @@ -646,6 +729,7 @@ const DataSourceList: React.FC = memo((props: any) => { return ( <> + {/* */} = memo((props: any) => { tempType = CLSAAIFIED_TYPE.SystemMark; tempIsLoading = true; break; + case 'AUTHORIZED': + tempLabel = 'AUTHORIZED'; + tempType = CLSAAIFIED_TYPE.Authorized; + break; case 'ACTIVE': tempLabel = 'ACTIVE'; tempType = CLSAAIFIED_TYPE.Success; @@ -900,14 +988,21 @@ const DataSourceList: React.FC = memo((props: any) => { disabled={isLoading} iconName="refresh" /> - + } + { tagType === DATA_TYPE_ENUM.jdbc && + + } { if (detail.id === 'disconnect') { @@ -919,9 +1014,9 @@ const DataSourceList: React.FC = memo((props: any) => { if (detail.id === 'deleteDatabase') { clkDeleteDatabase(); } - if (detail.id === 'addImportJdbc') { - clkAddSource('addImportJdbc'); - } + // if (detail.id === 'addImportJdbc') { + // clkAddSource('addImportJdbc'); + // } if (detail.id === 'editJdbc') { clkAddSource('editJdbc'); } @@ -959,7 +1054,7 @@ const DataSourceList: React.FC = memo((props: any) => { } items={pageData} - selectionType="single" + selectionType={tagType === DATA_TYPE_ENUM.jdbc?"multi":"single"} loadingText={t('table.loadingResources') || 
''} visibleColumns={preferences.visibleContent} empty={ @@ -1040,6 +1135,30 @@ const DataSourceList: React.FC = memo((props: any) => { } loading={isLoading} /> + + setShowDelResModal(false)} + footer={ + + + + + + } + header={t('datasource:jdbc.removeDataSource')} + > +

{t('datasource:jdbc.removeDataSourceFailed')}

+ {delFailedItems.map((item:any) => { + return ( + + {item[0]} - {item[1]} + + + ) + })} +

{t('datasource:jdbc.confirmReason')}

+
setShowRdsPwdModal(false)} @@ -1056,7 +1175,7 @@ const DataSourceList: React.FC = memo((props: any) => { onClick={connectRDS} loading={btnDisabled} > - {t('button.connect')} + {t('button.authorize')} @@ -1115,7 +1234,19 @@ const DataSourceList: React.FC = memo((props: any) => {

- + {/* + + setSelectedSgs(detail.selectedOptions) + } + options={sgs} + empty={t('datasource:emptySg')||''} + placeholder={t('datasource:chooseSg')||''} + /> + */} + + setCedentialType(detail.value)} value={cedentialType} diff --git a/source/portal/src/pages/data-source-connection/componments/JDBCConnection.tsx b/source/portal/src/pages/data-source-connection/componments/JDBCConnection.tsx index fcbea762..7cf2b198 100644 --- a/source/portal/src/pages/data-source-connection/componments/JDBCConnection.tsx +++ b/source/portal/src/pages/data-source-connection/componments/JDBCConnection.tsx @@ -12,6 +12,7 @@ import { Tiles, Textarea, Modal, + Grid, } from '@cloudscape-design/components'; import { listGlueConnection, @@ -25,8 +26,8 @@ import { import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { alertMsg } from 'tools/tools'; -import { i18ns } from '../types/s3_selector_config'; import { DropdownStatusProps } from '@cloudscape-design/components/internal/components/dropdown-status'; +import { checkJDBCIsMySQL } from 'ts/common'; interface JDBCConnectionProps { providerId: number; @@ -102,6 +103,11 @@ const JDBCConnection: React.FC = ( const [secretItem, setSecretItem] = useState(null); const [loadingJdbcDatabase, setLoadingJdbcDatabase] = useState(false); + const [jdbcConnType, setJdbcConnType] = useState('mysql'); + const [tmpJDBCUrl, setTmpJDBCUrl] = useState(''); + const [otherJDBCUrlError, setOtherJDBCUrlError] = useState(false); + const [jdbcDatabaseEmptyError, setJdbcDatabaseEmptyError] = useState(false); + useEffect(() => { if (credentialType === 'secret_manager') { loadAccountSecrets(); @@ -341,6 +347,14 @@ const JDBCConnection: React.FC = ( }; const addJdbcConnection = async () => { + if (jdbcConnType === 'other' && checkJDBCIsMySQL(tmpJDBCUrl)) { + setOtherJDBCUrlError(true); + return; + } + if (!jdbcConnectionData?.new?.jdbc_connection_schema?.trim()) { + setJdbcDatabaseEmptyError(true); + return; + } setLoadingImport(true); if (jdbcConnectionData.createType === 'import') { try { @@ -390,6 +404,14 @@ const JDBCConnection: React.FC = ( setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); }; + useEffect(() => { + let jdbcURLStr = tmpJDBCUrl; + if (jdbcConnType === 'mysql') { + jdbcURLStr = 'jdbc:mysql://' + tmpJDBCUrl; + } + changeJDBCUrl(jdbcURLStr); + }, [tmpJDBCUrl]); + const changeDatabase = (detail: any) => { // console.log(detail) let temp = jdbcConnectionData.new; @@ -461,39 +483,39 @@ const JDBCConnection: React.FC = ( setBuckets(res); }; - const changeJDBCcertificate = (detail: any) => { - let temp = jdbcConnectionData.new; - temp = { ...temp, custom_jdbc_cert: detail.resource.uri }; - setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); - }; - - const changeSkipCerValid = (detail: any) => { - // console.log("skip!!!",detail) - let temp = jdbcConnectionData.new; - temp = { - ...temp, - skip_custom_jdbc_cert_validation: detail ? 
'true' : 'false', - }; - setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); - }; - - const changeJDBCCertString = (detail: any) => { - let temp = jdbcConnectionData.new; - temp = { ...temp, custom_jdbc_cert_string: detail }; - setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); - }; - - const changeDriverClassName = (detail: any) => { - let temp = jdbcConnectionData.new; - temp = { ...temp, jdbc_driver_class_name: detail }; - setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); - }; - - const changeDriverPath = (detail: any) => { - let temp = jdbcConnectionData.new; - temp = { ...temp, jdbc_driver_jar_uri: detail.resource.uri }; - setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); - }; + // const changeJDBCcertificate = (detail: any) => { + // let temp = jdbcConnectionData.new; + // temp = { ...temp, custom_jdbc_cert: detail.resource.uri }; + // setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); + // }; + + // const changeSkipCerValid = (detail: any) => { + // // console.log("skip!!!",detail) + // let temp = jdbcConnectionData.new; + // temp = { + // ...temp, + // skip_custom_jdbc_cert_validation: detail ? 'true' : 'false', + // }; + // setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); + // }; + + // const changeJDBCCertString = (detail: any) => { + // let temp = jdbcConnectionData.new; + // temp = { ...temp, custom_jdbc_cert_string: detail }; + // setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); + // }; + + // const changeDriverClassName = (detail: any) => { + // let temp = jdbcConnectionData.new; + // temp = { ...temp, jdbc_driver_class_name: detail }; + // setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); + // }; + + // const changeDriverPath = (detail: any) => { + // let temp = jdbcConnectionData.new; + // temp = { ...temp, jdbc_driver_jar_uri: detail.resource.uri }; + // setJdbcConnectionData({ ...jdbcConnectionData, new: temp }); + // }; const changeUserName = (detail: any) => { let temp = jdbcConnectionData.new; @@ -515,6 +537,11 @@ const JDBCConnection: React.FC = ( }; const findDatabase = async () => { + if (jdbcConnType === 'other' && checkJDBCIsMySQL(tmpJDBCUrl)) { + setOtherJDBCUrlError(true); + return; + } + setLoadingImport(true); setLoadingJdbcDatabase(true); const requestParam = { @@ -522,6 +549,7 @@ const JDBCConnection: React.FC = ( username: jdbcConnectionData.new.master_username, password: jdbcConnectionData.new.password, secret_id: jdbcConnectionData.new.secret, + ssl_verify_cert: jdbcConnectionData.new.jdbc_enforce_ssl === "true" ? true: false }; try { const res: any = await queryJdbcDatabases(requestParam); @@ -734,18 +762,55 @@ const JDBCConnection: React.FC = ( value={jdbcConnectionData.new.description} /> + + + { + setOtherJDBCUrlError(false); + setTmpJDBCUrl(''); + changeDatabase(''); + setJdbcConnType(detail.value); + }} + value={jdbcConnType} + items={[ + { + label: t('datasource:jdbc.mysql'), + value: 'mysql', + }, + { label: t('datasource:jdbc.other'), value: 'other' }, + ]} + /> + + <> - changeJDBCUrl(e.detail.value)} - placeholder="jdbc:protocol://host:port" - value={jdbcConnectionData.new.jdbc_connection_url} - /> +
+ {jdbcConnType === 'mysql' && ( +
jdbc:mysql://
+ )} +
+ { + setOtherJDBCUrlError(false); + setTmpJDBCUrl(e.detail.value); + }} + placeholder={ + jdbcConnType === 'mysql' + ? 'host:port' + : 'jdbc:protocol://host:port' + } + value={tmpJDBCUrl} + /> +
+
{/* = ( {credential === 'secret' && ( + { + changeDatabase(''); changeUserName(detail.value); }} /> { + changeDatabase(''); changePassword(detail.value); }} /> - - )} - */} + {props.providerId !== 1 && ( +
- ) +
+ )} + {/*
*/} +
+ )} +
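
For reference, a minimal sketch of the daily-schedule conversion introduced above in source/portal/src/pages/create-job/index.tsx: cronGeneratorForGlueDaily shifts the selected wall-clock time back by a fixed 8 hours (which appears to assume a UTC+8 deployment) before emitting a Glue-style cron expression. The function body restates that hunk; the standalone function name and the console usage below are illustrative only, not part of the patch.

```ts
// Sketch only: mirrors cronGeneratorForGlueDaily from the create-job hunk above.
// The fixed 8-hour subtraction assumes the chosen time is UTC+8 wall-clock time.
import moment from 'moment';

const dailyCronFromLocalTime = (time: string): string => {
  const timeMoment = moment(time, 'HH:mm');
  timeMoment.subtract(8, 'hours'); // shift the assumed UTC+8 time back to UTC
  const hours = timeMoment.format('H');
  const minutes = timeMoment.format('m');
  // Glue/EventBridge-style cron fields: minutes hours day-of-month month day-of-week year
  return `${minutes} ${hours} * * ? *`;
};

// Example (illustrative): a daily scan scheduled for 09:30 UTC+8 becomes "30 1 * * ? *".
console.log(dailyCronFromLocalTime('09:30'));
```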