Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TrinoStatementExecError #4842

Merged
merged 6 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ omit =
*manage.py
*celery.py
*configurator.py
*database.py
database.py
*feature_flags.py
*probe_server.py
*settings.py
Expand Down
95 changes: 81 additions & 14 deletions koku/koku/test_trino_db_utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import uuid

from django.test import TestCase
from jinjasql import JinjaSql
from trino.dbapi import Connection
from trino.exceptions import TrinoQueryError

from . import trino_database as trino_db
from api.iam.test.iam_test_case import FakeTrinoConn
from api.iam.test.iam_test_case import FakeTrinoCur
from api.iam.test.iam_test_case import IamTestCase
from koku.trino_database import connect
from koku.trino_database import executescript
from koku.trino_database import PreprocessStatementError
from koku.trino_database import TrinoStatementExecError


class TestTrinoDatabaseUtils(IamTestCase):
def test_connect(self):
"""
Test connection to trino returns trino.dbapi.Connection instance
"""
conn = trino_db.connect(schema=self.schema_name, catalog="hive")
conn = connect(schema=self.schema_name, catalog="hive")
self.assertTrue(isinstance(conn, Connection))
self.assertEqual(conn.schema, self.schema_name)
self.assertEqual(conn.catalog, "hive")
Expand Down Expand Up @@ -49,10 +54,10 @@ def test_executescript(self):
"int_data": 255,
"txt_data": "This is a test",
}
results = trino_db.executescript(conn, sqlscript, params=params, preprocessor=JinjaSql().prepare_query)
results = executescript(conn, sqlscript, params=params, preprocessor=JinjaSql().prepare_query)
self.assertEqual(results, [["eek"], ["eek"], ["eek"], ["eek"], ["eek"], ["eek"]])

def test_executescript_err(self):
def test_executescript_preprocessor_error(self):
"""
Test executescript will raise a preprocessor error
"""
Expand All @@ -61,22 +66,46 @@ def test_executescript_err(self):
select * from eek where val1 in {{val_list}};
"""
params = {"val_list": (1, 2, 3, 4, 5)}
with self.assertRaises(trino_db.PreprocessStatementError):
trino_db.executescript(conn, sqlscript, params=params, preprocessor=JinjaSql().prepare_query)
with self.assertRaises(PreprocessStatementError):
executescript(conn, sqlscript, params=params, preprocessor=JinjaSql().prepare_query)

def test_executescript_no_preprocess(self):
def test_executescript_no_preprocessor_error(self):
"""
Test executescript will raise a preprocessor error
Test executescript will not raise a preprocessor error
"""
sqlscript = """
select x from y;
select a from b;
"""
conn = FakeTrinoConn()
res = trino_db.executescript(conn, sqlscript)
res = executescript(conn, sqlscript)
self.assertEqual(res, [["eek"], ["eek"]])

def test_preprocessor_err(self):
def test_executescript_trino_error(self):
"""
Test that executescirpt will raise a TrinoStatementExecError
"""

class FakeTrinoConn:
@property
def cursor(self):
raise TrinoQueryError(
{
"errorName": "REMOTE_TASK_ERROR",
"errorType": "INTERNAL_ERROR",
"message": "Expected response code",
}
)

with (
self.assertRaisesRegex(TrinoStatementExecError, "type=INTERNAL_ERROR"),
self.assertLogs("koku.trino_database", level="WARN") as logger,
):
executescript(FakeTrinoConn(), "SELECT x from y")

self.assertIn("WARNING:koku.trino_database:Trino Query Error", logger.output[0])

def test_preprocessor_error(self):
def t_preprocessor(*args):
raise TypeError("This is a test")

Expand All @@ -86,8 +115,8 @@ def t_preprocessor(*args):
"""
params = {"eek": 1}
conn = FakeTrinoConn()
with self.assertRaises(trino_db.PreprocessStatementError):
trino_db.executescript(conn, sqlscript, params=params, preprocessor=t_preprocessor)
with self.assertRaises(PreprocessStatementError):
executescript(conn, sqlscript, params=params, preprocessor=t_preprocessor)

def test_executescript_error(self):
def t_exec_error(*args, **kwargs):
Expand All @@ -105,6 +134,44 @@ def cursor(self):
select x from y;
select a from b;
"""
with self.assertRaises(trino_db.TrinoStatementExecError):
with self.assertRaises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick question: Curious why this is now a ValueError from TrinoStatementExecError?

Copy link
Contributor Author

@samdoran samdoran Dec 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was hoping someone would ask about that. I changed the behavior in executescript() to only raise TrinoStatementExecError if a TrinoQueryError was raised since that is now a required parameter of TrinoStatementExecError. Otherwise, just raise the original exception.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation 👍🏾 , and I see the ValueError raised by FakerFakeTrinoConn

conn = FakerFakeTrinoConn()
trino_db.executescript(conn, sqlscript)
executescript(conn, sqlscript)


class TestTrinoStatementExecError(TestCase):
def test_trino_statement_exec_error(self):
"""Test TestTrinoStatementExecError behavior"""

trino_query_error = TrinoQueryError(
{
"errorName": "REMOTE_TASK_ERROR",
"errorCode": 99,
"errorType": "INTERNAL_ERROR",
"failureInfo": {"type": "CLOUD_TROUBLES"},
"message": "Expected response code",
"errorLocation": {"lineNumber": 42, "columnNumber": 24},
},
query_id="20231220_165606_25626_c7r5y",
)

trino_statement_error = TrinoStatementExecError("SELECT x from y", 1, {}, trino_query_error)

expected_str = (
"Trino Query Error (TrinoQueryError) : TrinoQueryError(type=INTERNAL_ERROR, name=REMOTE_TASK_ERROR, "
'message="Expected response code", query_id=20231220_165606_25626_c7r5y) statement number 1\n'
"Statement: SELECT x from y\n"
"Parameters: {}"
)
expected_repr = (
"TrinoStatementExecError("
"type=INTERNAL_ERROR, "
"name=REMOTE_TASK_ERROR, "
"message=Expected response code, "
"query_id=20231220_165606_25626_c7r5y)"
)
self.assertEqual(str(trino_statement_error), expected_str)
self.assertEqual(repr(trino_statement_error), expected_repr)
self.assertEqual(trino_statement_error.error_code, 99)
self.assertEqual(trino_statement_error.error_exception, "CLOUD_TROUBLES")
self.assertEqual(trino_statement_error.failure_info, {"type": "CLOUD_TROUBLES"})
88 changes: 72 additions & 16 deletions koku/koku/trino_database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import os
import re
import typing as t

import sqlparse
import trino
from trino.exceptions import TrinoQueryError
from trino.transaction import IsolationLevel


Expand All @@ -19,7 +21,62 @@ class PreprocessStatementError(Exception):


class TrinoStatementExecError(Exception):
pass
def __init__(
self,
statement: str,
statement_number: int,
sql_params: dict[str, t.Any],
trino_error: t.Optional[TrinoQueryError] = None,
):
self.statement = statement
self.statement_number = statement_number
self.sql_params = sql_params
self._trino_error = trino_error

def __repr__(self):
return (
f"{self.__class__.__name__}("
f"type={self.error_type}, "
f"name={self.error_name}, "
f"message={self.message}, "
f"query_id={self.query_id})"
)

def __str__(self):
return (
f"Trino Query Error ({self._trino_error.__class__.__name__}) : "
f"{self._trino_error} statement number {self.statement_number}{os.linesep}"
f"Statement: {self.statement}{os.linesep}"
f"Parameters: {self.sql_params}"
)

@property
def error_code(self) -> t.Optional[int]:
return self._trino_error.error_code

@property
def error_name(self) -> t.Optional[str]:
return self._trino_error.error_name

@property
def error_type(self) -> t.Optional[str]:
return self._trino_error.error_type

@property
def error_exception(self) -> t.Optional[str]:
return self._trino_error.error_exception

@property
def failure_info(self) -> t.Optional[dict[str, t.Any]]:
return self._trino_error.failure_info

@property
def message(self) -> str:
return self._trino_error.message

@property
def query_id(self) -> t.Optional[str]:
return self._trino_error.query_id


def connect(**connect_args):
Expand Down Expand Up @@ -78,31 +135,30 @@ def executescript(trino_conn, sqlscript, *, params=None, preprocessor=None):
if preprocessor and params:
try:
stmt, s_params = preprocessor(p_stmt, params)
except Exception as e:
except Exception as exc:
LOG.warning(
f"Preprocessor Error ({e.__class__.__name__}) : {str(e)}{os.linesep}"
+ f"Statement template : {p_stmt}"
+ os.linesep
+ f"Parameters : {params}"
f"Preprocessor Error ({exc.__class__.__name__}) : {exc}{os.linesep}"
f"Statement template : {p_stmt}{os.linesep}"
f"Parameters : {params}"
)
exc_type = e.__class__.__name__
raise PreprocessStatementError(f"{exc_type} :: {e}") from e
exc_type = exc.__class__.__name__
raise PreprocessStatementError(f"{exc_type} :: {exc}") from exc
else:
stmt, s_params = p_stmt, params

try:
cur = trino_conn.cursor()
cur.execute(stmt, params=s_params)
results = cur.fetchall()
except Exception as e:
exc_msg = (
f"Trino Query Error ({e.__class__.__name__}) : {str(e)} statement number {stmt_num}{os.linesep}"
+ f"Statement: {stmt}"
+ os.linesep
+ f"Parameters: {s_params}"
except TrinoQueryError as trino_exc:
trino_statement_error = TrinoStatementExecError(
statement=stmt, statement_number=stmt_num, sql_params=s_params, trino_error=trino_exc
)
LOG.warning(exc_msg)
raise TrinoStatementExecError(exc_msg) from e
LOG.warning(f"{trino_statement_error!s}")
raise trino_statement_error from trino_exc
except Exception as exc:
LOG.warning(str(exc))
raise

all_results.extend(results)

Expand Down
10 changes: 5 additions & 5 deletions koku/masu/database/ocp_report_db_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,14 @@ def populate_line_item_daily_summary_table_trino(
try:
self._execute_trino_multipart_sql_query(sql, bind_params=sql_params)
except TrinoStatementExecError as trino_exc:
if "one or more partitions already exist" in str(trino_exc).lower():
if trino_exc.error_name == "ALREADY_EXISTS":
LOG.warning(
log_json(
ctx=self.extract_context_from_sql_params(sql_params),
msg=getattr(trino_exc.__cause__, "message", None),
error_type=getattr(trino_exc.__cause__, "error_type", None),
error_name=getattr(trino_exc.__cause__, "error_name", None),
query_id=getattr(trino_exc.__cause__, "query_id", None),
msg=trino_exc.message,
error_type=trino_exc.error_type,
error_name=trino_exc.error_name,
query_id=trino_exc.query_id,
)
)
else:
Expand Down
53 changes: 17 additions & 36 deletions koku/masu/test/database/test_ocp_report_db_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from django.db.models import Q
from django.db.models import Sum
from trino.exceptions import TrinoExternalError
from trino.exceptions import TrinoUserError

from api.iam.test.iam_test_case import FakeTrinoConn
from api.provider.models import Provider
Expand Down Expand Up @@ -105,8 +106,14 @@ def test_populate_line_item_daily_summary_table_trino_exception_warn(self, mock_
self.accessor as acc,
self.assertLogs("masu.database.ocp_report_db_accessor", level="WARN") as logger,
):
mock_sql_query.side_effect = TrinoStatementExecError(message)
mock_sql_query.side_effect.__cause__ = TrinoExternalError({"message": message})
trino_error = TrinoUserError(
{
"errorType": "USER_ERROR",
"errorName": "ALREADY_EXISTS",
"message": message,
}
)
mock_sql_query.side_effect = TrinoStatementExecError("SELECT * from table", 1, {}, trino_error)

acc.populate_line_item_daily_summary_table_trino(
start_date, end_date, report_period_id, cluster_id, cluster_alias, source
Expand All @@ -133,47 +140,21 @@ def test_populate_line_item_daily_summary_table_trino_exception(self, mock_table
patch.object(self.accessor, "delete_ocp_hive_partition_by_day"),
patch.object(self.accessor, "_execute_trino_multipart_sql_query") as mock_sql_query,
self.accessor as acc,
self.assertRaisesRegex(TrinoStatementExecError, "Some other reason"),
self.assertRaisesRegex(TrinoStatementExecError, "Something went wrong"),
):
mock_sql_query.side_effect = TrinoStatementExecError("Some other reason")

acc.populate_line_item_daily_summary_table_trino(
start_date, end_date, report_period_id, cluster_id, cluster_alias, source
trino_error = TrinoExternalError(
{
"errorType": "EXTERNAL",
"errorName": "SOME_EXTERNAL_PROBLEM",
"message": "Something went wrong",
}
)

@patch("masu.database.ocp_report_db_accessor.trino_table_exists", return_value=True)
def test_populate_line_item_daily_summary_table_trino_exception_other(self, mock_table_exists):
"""
Test that a warning is logged when a TrinoStatementExecError is raised because
a partion already exists and that no errors are encountered if an exception other
than a TrinoQueryError is the cause.
"""

start_date = self.dh.this_month_start
end_date = self.dh.next_month_start
cluster_id = "ocp-cluster"
cluster_alias = "OCP FTW"
report_period_id = 1
source = self.provider_uuid
message = "One or more Partitions Already exist"
with (
patch.object(self.accessor, "delete_ocp_hive_partition_by_day"),
patch.object(self.accessor, "_execute_trino_multipart_sql_query") as mock_sql_query,
self.accessor as acc,
self.assertLogs("masu.database.ocp_report_db_accessor", level="WARN") as logger,
):
mock_sql_query.side_effect = TrinoStatementExecError(message)
mock_sql_query.side_effect.__cause__ = ValueError("Some other exception type")
mock_sql_query.side_effect = TrinoStatementExecError("SELECT * from table", 1, {}, trino_error)

acc.populate_line_item_daily_summary_table_trino(
start_date, end_date, report_period_id, cluster_id, cluster_alias, source
)

self.assertIn(
"WARNING:masu.database.ocp_report_db_accessor:{'message': None",
logger.output[0],
)

@patch("masu.database.ocp_report_db_accessor.trino_table_exists")
@patch("masu.database.ocp_report_db_accessor.pkgutil.get_data")
@patch("masu.database.report_db_accessor_base.trino_db.connect")
Expand Down
Loading