Feat: Download Worksheet Results #1547

Open · wants to merge 9 commits into main
Changes from 3 commits
29 changes: 29 additions & 0 deletions backend/dataall/base/aws/s3_client.py
@@ -3,6 +3,7 @@

from botocore.exceptions import ClientError
import logging
from dataall.base.db.exceptions import AWSResourceNotFound

log = logging.getLogger(__name__)

@@ -33,3 +34,31 @@ def get_presigned_url(region, bucket, key, expire_minutes: int = 15):
except ClientError as e:
log.error(f'Failed to get presigned URL due to: {e}')
raise e

@staticmethod
def object_exists(region, bucket, key) -> bool:
try:
S3_client.client(region, None).head_object(Bucket=bucket, Key=key)
return True
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                log.info(f'Object {key} not found in bucket {bucket}')
                return False
            log.error(f'Failed to check object existence due to: {e}')
            raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}')

@staticmethod
def put_object(region, bucket, key, body):
try:
S3_client.client(region, None).put_object(Bucket=bucket, Key=key, Body=body)
except ClientError as e:
log.error(f'Failed to put object due to: {e}')
raise e

@staticmethod
def get_object(region, bucket, key):
try:
response = S3_client.client(region, None).get_object(Bucket=bucket, Key=key)
return response['Body'].read().decode('utf-8')
except ClientError as e:
log.error(f'Failed to get object due to: {e}')
raise e
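For context, a minimal usage sketch of the new static helpers; the region, bucket, and key values are placeholders, not part of the PR:

from dataall.base.aws.s3_client import S3_client

region, bucket, key = 'eu-west-1', 'example-bucket', 'athenaqueries/workgroup/query-id.csv'
if S3_client.object_exists(region, bucket, key):
    # get_object returns the body decoded as UTF-8 text
    csv_text = S3_client.get_object(region, bucket, key)
    # put_object forwards body directly to boto3, so str, bytes or a file-like object all work
    S3_client.put_object(region, bucket, f'{key}.bak', csv_text)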
11 changes: 11 additions & 0 deletions backend/dataall/modules/worksheets/api/input_types.py
@@ -75,3 +75,14 @@
gql.Argument(name='measures', type=gql.ArrayType(gql.Ref('WorksheetMeasureInput'))),
],
)


WorksheetQueryResultDownloadUrlInput = gql.InputType(
name='WorksheetQueryResultDownloadUrlInput',
arguments=[
gql.Argument(name='athenaQueryId', type=gql.NonNullableType(gql.String)),
gql.Argument(name='fileFormat', type=gql.NonNullableType(gql.String)),
gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='worksheetUri', type=gql.NonNullableType(gql.String)),
],
)
16 changes: 15 additions & 1 deletion backend/dataall/modules/worksheets/api/mutations.py
@@ -1,5 +1,10 @@
from dataall.base.api import gql
from dataall.modules.worksheets.api.resolvers import create_worksheet, delete_worksheet, update_worksheet
from dataall.modules.worksheets.api.resolvers import (
create_worksheet,
delete_worksheet,
update_worksheet,
create_athena_query_result_download_url,
)


createWorksheet = gql.MutationField(
@@ -27,3 +32,12 @@
],
type=gql.Boolean,
)

createWorksheetQueryResultDownloadUrl = gql.MutationField(
name='createWorksheetQueryResultDownloadUrl',
resolver=create_athena_query_result_download_url,
args=[
gql.Argument(name='input', type=gql.Ref('WorksheetQueryResultDownloadUrlInput')),
],
type=gql.Ref('WorksheetQueryResult'),
)
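A hedged sketch of how the new mutation could be called; the selection set is taken from the WorksheetQueryResult type in this PR, while the client helper and the variable values are hypothetical:

CREATE_DOWNLOAD_URL = """
mutation createWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput) {
  createWorksheetQueryResultDownloadUrl(input: $input) {
    AthenaQueryId
    fileFormat
    downloadLink
    expiresIn
    outputLocation
  }
}
"""

variables = {
    'input': {
        'athenaQueryId': 'query-id',      # placeholder
        'fileFormat': 'csv',              # 'csv' or 'xlsx'
        'environmentUri': 'env-uri',      # placeholder
        'worksheetUri': 'worksheet-uri',  # placeholder
    }
}
# response = graphql_client.execute(CREATE_DOWNLOAD_URL, variables)  # hypothetical test client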
21 changes: 20 additions & 1 deletion backend/dataall/modules/worksheets/api/resolvers.py
@@ -1,9 +1,10 @@
from dataall.base.db import exceptions
from dataall.modules.worksheets.api.enums import WorksheetRole
from dataall.modules.worksheets.services.worksheet_enums import WorksheetRole, WorksheetResultsFormat
from dataall.modules.worksheets.db.worksheet_models import Worksheet
from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository
from dataall.modules.worksheets.services.worksheet_service import WorksheetService
from dataall.base.api.context import Context
from dataall.modules.worksheets.services.worksheet_query_result_service import WorksheetQueryResultService


def create_worksheet(context: Context, source, input: dict = None):
@@ -69,3 +70,21 @@ def run_sql_query(context: Context, source, environmentUri: str = None, workshee
def delete_worksheet(context, source, worksheetUri: str = None):
with context.engine.scoped_session() as session:
return WorksheetService.delete_worksheet(session=session, uri=worksheetUri)


def create_athena_query_result_download_url(context: Context, source, input: dict = None):

if not input:
# raise exceptions.InvalidInput('data', input, 'input is required')
raise exceptions.RequiredParameter('data')
if not input.get('athenaQueryId'):
raise exceptions.RequiredParameter('athenaQueryId')
if not input.get('fileFormat'):
raise exceptions.RequiredParameter('fileFormat')
if not hasattr(WorksheetResultsFormat, input.get('fileFormat').upper()):
raise exceptions.InvalidInput(
'fileFormat', input.get('fileFormat'),
', '.join(result_format.value for result_format in WorksheetResultsFormat))

with context.engine.scoped_session() as session:
return WorksheetQueryResultService.download_sql_query_result(session=session, data=input)
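The fileFormat guard relies on the enum member names being upper-case; a small illustration, where the helper function and asserts exist only for this write-up:

from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat

def is_supported_format(file_format: str) -> bool:
    # mirrors the resolver check: members are CSV and XLSX with values 'csv' and 'xlsx'
    return hasattr(WorksheetResultsFormat, file_format.upper())

assert is_supported_format('csv') and is_supported_format('XLSX')
assert not is_supported_format('parquet')  # the resolver raises exceptions.InvalidInput for this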
11 changes: 6 additions & 5 deletions backend/dataall/modules/worksheets/api/types.py
@@ -86,15 +86,16 @@
name='WorksheetQueryResult',
fields=[
gql.Field(name='worksheetQueryResultUri', type=gql.ID),
gql.Field(name='queryType', type=gql.NonNullableType(gql.String)),
gql.Field(name='sqlBody', type=gql.NonNullableType(gql.String)),
gql.Field(name='sqlBody', type=gql.String),
gql.Field(name='AthenaQueryId', type=gql.NonNullableType(gql.String)),
gql.Field(name='region', type=gql.NonNullableType(gql.String)),
gql.Field(name='AwsAccountId', type=gql.NonNullableType(gql.String)),
gql.Field(name='AthenaOutputBucketName', type=gql.NonNullableType(gql.String)),
gql.Field(name='AthenaOutputKey', type=gql.NonNullableType(gql.String)),
gql.Field(name='timeElapsedInSecond', type=gql.NonNullableType(gql.Integer)),
gql.Field(name='elapsedTimeInMs', type=gql.Integer),
gql.Field(name='created', type=gql.NonNullableType(gql.String)),
gql.Field(name='downloadLink', type=gql.String),
gql.Field(name='outputLocation', type=gql.String),
gql.Field(name='expiresIn', type=gql.AWSDateTime),
gql.Field(name='fileFormat', type=gql.String),
],
)

73 changes: 73 additions & 0 deletions backend/dataall/modules/worksheets/aws/s3_client.py
@@ -0,0 +1,73 @@
import boto3
from botocore.config import Config

from botocore.exceptions import ClientError
import logging
from dataall.base.db.exceptions import AWSResourceNotFound
from dataall.base.aws.sts import SessionHelper

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dataall.core.environment.db.environment_models import Environment
try:
from mypy_boto3_s3 import S3Client as S3ClientType
except ImportError:
S3ClientType = None

log = logging.getLogger(__name__)


class S3Client:

def __init__(self, env: 'Environment'):
self._client = SessionHelper.remote_session(env.AwsAccountId, env.region).client('s3', region_name=env.region)
self._env = env

@property
def client(self) -> 'S3ClientType':
return self._client

def get_presigned_url(self, bucket, key, expire_minutes: int = 15):
expire_seconds = expire_minutes * 60
try:
presigned_url = self.client.generate_presigned_url(
'get_object',
Params=dict(
Bucket=bucket,
Key=key,
),
ExpiresIn=expire_seconds,
)
return presigned_url
except ClientError as e:
log.error(f'Failed to get presigned URL due to: {e}')
raise e

def object_exists(self, bucket, key) -> bool:
try:
self.client.head_object(Bucket=bucket, Key=key)
return True
except ClientError as e:
if e.response['Error']['Code'] == '404':
log.info(f'Object {key} not found in bucket {bucket}')
return False
log.error(f'Failed to check object existence due to: {e}')
raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}')


def put_object(self, bucket, key, body):
try:
self.client.put_object(Bucket=bucket, Key=key, Body=body)
except ClientError as e:
log.error(f'Failed to put object due to: {e}')
raise e


def get_object(self, bucket, key) -> str:
try:
response = self.client.get_object(Bucket=bucket, Key=key)
return response['Body'].read().decode('utf-8')
except ClientError as e:
log.error(f'Failed to get object due to: {e}')
raise e
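A usage sketch for the environment-scoped wrapper; the environment record, bucket, and key are placeholders and error handling is reduced to the essentials:

from dataall.modules.worksheets.aws.s3_client import S3Client

def presign_result_file(environment, bucket: str, key: str) -> str:
    s3 = S3Client(environment)             # assumes a remote session into environment.AwsAccountId
    if not s3.object_exists(bucket, key):  # False on 404, raises on other client errors
        raise ValueError(f'{key} is not available yet in {bucket}')
    return s3.get_presigned_url(bucket, key, expire_minutes=15)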
21 changes: 15 additions & 6 deletions backend/dataall/modules/worksheets/db/worksheet_models.py
@@ -1,7 +1,8 @@
import datetime
import enum

from sqlalchemy import Column, DateTime, Integer, Enum, String
from sqlalchemy import Column, DateTime, Integer, Enum, String, BigInteger
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import query_expression

@@ -27,15 +28,23 @@ class Worksheet(Resource, Base):

class WorksheetQueryResult(Base):
__tablename__ = 'worksheet_query_result'
worksheetQueryResultUri = Column(String, primary_key=True, default=utils.uuid('worksheetQueryResultUri'))
worksheetUri = Column(String, nullable=False)
AthenaQueryId = Column(String, primary_key=True)
status = Column(String, nullable=False)
queryType = Column(Enum(QueryType), nullable=False, default=True)
sqlBody = Column(String, nullable=False)
AthenaQueryId = Column(String, nullable=False)
status = Column(String, nullable=True)
sqlBody = Column(String, nullable=True)
AwsAccountId = Column(String, nullable=False)
region = Column(String, nullable=False)
OutputLocation = Column(String, nullable=False)
error = Column(String, nullable=True)
ElapsedTimeInMs = Column(Integer, nullable=True)
DataScannedInBytes = Column(Integer, nullable=True)
DataScannedInBytes = Column(BigInteger, nullable=True)
created = Column(DateTime, default=datetime.datetime.now)

downloadLink = Column(String, nullable=True)
expiresIn = Column(DateTime, nullable=True)
updated = Column(DateTime, nullable=False, onupdate=datetime.datetime.utcnow, default=datetime.datetime.utcnow)
fileFormat = Column(String, nullable=True)

def is_download_link_expired(self):
return self.expiresIn is None or self.expiresIn <= datetime.datetime.utcnow()
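As a quick illustration of the expiry helper (column values are placeholders): a freshly created row has no expiresIn, so the first download request always mints a new presigned URL.

row = WorksheetQueryResult(
    worksheetUri='worksheet-uri',
    AthenaQueryId='query-id',
    AwsAccountId='111111111111',
    region='eu-west-1',
    OutputLocation='s3://example-bucket/athenaqueries/workgroup/',
    fileFormat='csv',
)
assert row.is_download_link_expired()  # expiresIn is None until a link is issued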
14 changes: 14 additions & 0 deletions backend/dataall/modules/worksheets/db/worksheet_repositories.py
@@ -53,3 +53,17 @@ def paginated_user_worksheets(session, username, groups, uri, data=None, check_p
page=data.get('page', WorksheetRepository._DEFAULT_PAGE),
page_size=data.get('pageSize', WorksheetRepository._DEFAULT_PAGE_SIZE),
).to_dict()

@staticmethod
def find_query_result_by_format(
session, worksheet_uri: str, athena_query_id: str, file_format: str
) -> WorksheetQueryResult:
return (
session.query(WorksheetQueryResult)
.filter(
WorksheetQueryResult.worksheetUri == worksheet_uri,
WorksheetQueryResult.AthenaQueryId == athena_query_id,
WorksheetQueryResult.fileFormat == file_format,
)
.first()
)
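Sketch of how the service layer can look up a cached result for a given format inside a scoped session; the URIs and query id are placeholders and context is the GraphQL request context used elsewhere in this module:

with context.engine.scoped_session() as session:
    cached = WorksheetRepository.find_query_result_by_format(
        session,
        worksheet_uri='worksheet-uri',
        athena_query_id='query-id',
        file_format='xlsx',
    )
    if cached is None:
        pass  # no xlsx variant recorded yet; the service converts the csv output and uploads it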
backend/dataall/modules/worksheets/services/worksheet_enums.py
@@ -5,3 +5,8 @@ class WorksheetRole(GraphQLEnumMapper):
Creator = '950'
Admin = '900'
NoPermission = '000'


class WorksheetResultsFormat(GraphQLEnumMapper):
CSV = 'csv'
XLSX = 'xlsx'
backend/dataall/modules/worksheets/services/worksheet_query_result_service.py
@@ -0,0 +1,114 @@
import csv
import io
import os
from datetime import datetime, timedelta, UTC as DATETIME_UTC
from typing import TYPE_CHECKING

from openpyxl import Workbook

from dataall.base.db import exceptions
from dataall.core.environment.services.environment_service import EnvironmentService
from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
from dataall.modules.worksheets.aws.s3_client import S3Client
from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult
from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository
from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat
from dataall.modules.worksheets.services.worksheet_permissions import RUN_ATHENA_QUERY
from dataall.modules.worksheets.services.worksheet_service import WorksheetService

if TYPE_CHECKING:
try:
from sqlalchemy.orm import Session
from openpyxl.worksheet.worksheet import Worksheet
except ImportError:
print('skipping type checks as stubs are not installed')
Session = None
Worksheet = None


class WorksheetQueryResultService:
_DEFAULT_ATHENA_QUERIES_PATH = 'athenaqueries'
    _DEFAULT_QUERY_RESULTS_TIMEOUT = int(os.getenv('QUERY_RESULT_TIMEOUT_MINUTES', 120))  # cast to int: os.getenv returns a string when the variable is set

@staticmethod
def _create_query_result(
environment_bucket: str, athena_workgroup: str, worksheet_uri: str, region: str, aws_account_id: str, data: dict
) -> WorksheetQueryResult:
sql_query_result = WorksheetQueryResult(
worksheetUri=worksheet_uri,
AthenaQueryId=data.get('athenaQueryId'),
fileFormat=data.get('fileFormat'),
            OutputLocation=f's3://{environment_bucket}/{WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH}/{athena_workgroup}/',
region=region,
AwsAccountId=aws_account_id
)
return sql_query_result

@staticmethod
def build_s3_file_path(
workgroup: str, query_id: str, athena_queries_dir: str = None
) -> str:
athena_queries_dir = athena_queries_dir or WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH
return f'{athena_queries_dir}/{workgroup}/{query_id}'

@staticmethod
def convert_csv_to_xlsx(csv_data) -> io.BytesIO:
wb = Workbook()
ws: 'Worksheet' = wb.active
csv_reader = csv.reader(csv_data.splitlines())
for row in csv_reader:
ws.append(row)

excel_buffer = io.BytesIO()
wb.save(excel_buffer)
excel_buffer.seek(0)
return excel_buffer

@staticmethod
@ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY)
def download_sql_query_result(session: 'Session', data: dict = None):
environment = EnvironmentService.get_environment_by_uri(session, data.get('environmentUri'))
worksheet = WorksheetService.get_worksheet_by_uri(session, data.get('worksheetUri'))
env_group = EnvironmentService.get_environment_group(
session, worksheet.SamlAdminGroupName, environment.environmentUri
)
sql_query_result = WorksheetRepository.find_query_result_by_format(
session, data.get('worksheetUri'), data.get('athenaQueryId'), data.get('fileFormat')
)
s3_client = S3Client(environment)
if not sql_query_result:
sql_query_result = WorksheetQueryResultService._create_query_result(
environment.EnvironmentDefaultBucketName,
env_group.environmentAthenaWorkGroup,
worksheet.worksheetUri,
environment.region,
environment.AwsAccountId,
data,
)
output_file_s3_path = WorksheetQueryResultService.build_s3_file_path(
env_group.environmentAthenaWorkGroup, data.get('athenaQueryId')
)
if sql_query_result.fileFormat == WorksheetResultsFormat.XLSX.value:
try:
csv_data = s3_client.get_object(bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}')
excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data)
s3_client.put_object(bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', body=excel_buffer)
except Exception as e:
                raise exceptions.AWSResourceNotAvailable('CONVERT_CSV_TO_EXCEL', f'Failed to convert csv to xlsx: {e}')

        if not s3_client.object_exists(
            bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{sql_query_result.fileFormat}'
        ):
            # fail fast instead of presigning a URL for an object that is not there
            raise exceptions.AWSResourceNotFound('DOWNLOAD_QUERY_RESULT', f'{output_file_s3_path}.{sql_query_result.fileFormat} not found')
if sql_query_result.is_download_link_expired():
url = s3_client.get_presigned_url(
bucket=environment.EnvironmentDefaultBucketName,
key=f'{output_file_s3_path}.{sql_query_result.fileFormat}',
expire_minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT,
)
sql_query_result.downloadLink = url
            # naive UTC, matching the utcnow() comparison in is_download_link_expired()
            sql_query_result.expiresIn = datetime.now(DATETIME_UTC).replace(tzinfo=None) + timedelta(minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT)

session.add(sql_query_result)
session.commit()

return sql_query_result
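For completeness, a sketch of the conversion helpers in isolation (the CSV content and names are illustrative); this mirrors what download_sql_query_result does before uploading the xlsx variant:

csv_data = 'id,name\n1,alpha\n2,beta\n'
excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data)  # BytesIO rewound to position 0

key = WorksheetQueryResultService.build_s3_file_path('example-workgroup', 'query-id')
# key == 'athenaqueries/example-workgroup/query-id'; the service appends '.csv' or '.xlsx'
# s3_client.put_object(bucket='environment-bucket', key=f'{key}.xlsx', body=excel_buffer)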