From 21bdcb22d3babfccb448840557f9dd336ef515d6 Mon Sep 17 00:00:00 2001 From: Jaume Marhuenda Date: Sat, 9 Oct 2021 20:41:03 +0200 Subject: [PATCH] make sure calling GetStatus is safe (#164) --- hive.go | 49 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/hive.go b/hive.go index 97ae6e5..211b47c 100644 --- a/hive.go +++ b/hive.go @@ -358,8 +358,8 @@ func (c *Connection) Close() error { if err != nil { return err } - if !success(responseClose.GetStatus()) { - return errors.New("Error closing the session: " + responseClose.GetStatus().String()) + if !success(safeStatus(responseClose.GetStatus())) { + return errors.New("Error closing the session: " + safeStatus(responseClose.GetStatus()).String()) } return nil } @@ -534,11 +534,11 @@ func (c *Cursor) executeAsync(ctx context.Context, query string) { } return } - if !success(responseExecute.GetStatus()) { + if !success(safeStatus(responseExecute.GetStatus())) { c.Err = HiveError{ - error: errors.New("Error while executing query: " + responseExecute.GetStatus().String()), - Message: *responseExecute.GetStatus().ErrorMessage, - ErrorCode: int(*responseExecute.GetStatus().ErrorCode), + error: errors.New("Error while executing query: " + safeStatus(responseExecute.GetStatus()).String()), + Message: *safeStatus(responseExecute.GetStatus()).ErrorMessage, + ErrorCode: int(*safeStatus(responseExecute.GetStatus()).ErrorCode), } return } @@ -562,8 +562,8 @@ func (c *Cursor) Poll(getProgress bool) (status *hiveserver.TGetOperationStatusR if c.Err != nil { return nil } - if !success(responsePoll.GetStatus()) { - c.Err = errors.New("Error closing the operation: " + responsePoll.GetStatus().String()) + if !success(safeStatus(responsePoll.GetStatus())) { + c.Err = errors.New("Error closing the operation: " + safeStatus(responsePoll.GetStatus()).String()) return nil } return responsePoll @@ -960,7 +960,7 @@ func (c *Cursor) Description() [][]string { return nil } if metaResponse.Status.StatusCode != hiveserver.TStatusCode_SUCCESS_STATUS { - c.Err = errors.New(metaResponse.GetStatus().String()) + c.Err = errors.New(safeStatus(metaResponse.GetStatus()).String()) return nil } m := make([][]string, len(metaResponse.Schema.Columns)) @@ -1019,8 +1019,8 @@ func (c *Cursor) pollUntilData(ctx context.Context, n int) (err error) { } c.response = responseFetch - if responseFetch.GetStatus().StatusCode != hiveserver.TStatusCode_SUCCESS_STATUS { - rowsAvailable <- errors.New(responseFetch.GetStatus().String()) + if safeStatus(responseFetch.GetStatus()).StatusCode != hiveserver.TStatusCode_SUCCESS_STATUS { + rowsAvailable <- errors.New(safeStatus(responseFetch.GetStatus()).String()) return } err = c.parseResults(responseFetch) @@ -1071,8 +1071,8 @@ func (c *Cursor) Cancel() { if c.Err != nil { return } - if !success(responseCancel.GetStatus()) { - c.Err = errors.New("Error closing the operation: " + responseCancel.GetStatus().String()) + if !success(safeStatus(responseCancel.GetStatus())) { + c.Err = errors.New("Error closing the operation: " + safeStatus(responseCancel.GetStatus()).String()) } return } @@ -1100,8 +1100,8 @@ func (c *Cursor) resetState() error { if err != nil { return err } - if !success(responseClose.GetStatus()) { - return errors.New("Error closing the operation: " + responseClose.GetStatus().String()) + if !success(safeStatus(responseClose.GetStatus())) { + return errors.New("Error closing the operation: " + safeStatus(responseClose.GetStatus()).String()) } return nil } @@ -1178,3 +1178,22 @@ func newCookieJar() inMemoryCookieJar { f := false return inMemoryCookieJar{&f, storage} } + +func safeStatus(status *hiveserver.TStatus) *hiveserver.TStatus { + if (status == nil) { + return &DEFAULT_STATUS + } + return status +} + +var DEFAULT_SQL_STATE = "" +var DEFAULT_ERROR_CODE = int32(-1) +var DEFAULT_ERROR_MESSAGE = "unknown error" +var DEFAULT_STATUS = hiveserver.TStatus { + StatusCode: hiveserver.TStatusCode_ERROR_STATUS, + InfoMessages: nil, + SqlState: &DEFAULT_SQL_STATE, + ErrorCode: &DEFAULT_ERROR_CODE, + ErrorMessage: &DEFAULT_ERROR_MESSAGE, +} +