diff --git a/Makefile b/Makefile index 578225b..efac8aa 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,7 @@ install-deps: go get github.com/go-zookeeper/zk go get github.com/sirupsen/logrus go get github.com/stretchr/testify/assert + go get github.com/stretchr/testify/mock .PHONY: setup setup: install-covertools install-deps diff --git a/go.mod b/go.mod index 79b2a81..1aca8bf 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,14 @@ module github.com/startreedata/pinot-client-go require ( github.com/go-zookeeper/zk v1.0.3 - github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.0 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.9.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f8ec044..ef7bc4b 100644 --- a/go.sum +++ b/go.sum @@ -7,12 +7,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/pinot/connection.go b/pinot/connection.go index 7f3976c..1620d4a 100644 --- a/pinot/connection.go +++ b/pinot/connection.go @@ -1,6 +1,11 @@ package pinot import ( + "fmt" + "math/big" + "strings" + "time" + log "github.com/sirupsen/logrus" ) @@ -37,6 +42,62 @@ func (c *Connection) ExecuteSQL(table string, query string) (*BrokerResponse, er return brokerResp, err } +// ExecuteSQLWithParams executes an SQL query with parameters for a given table +func (c *Connection) ExecuteSQLWithParams(table string, queryPattern string, params []interface{}) (*BrokerResponse, error) { + query, err := formatQuery(queryPattern, params) + if err != nil { + log.Errorf("Failed to format query: %v\n", err) + return nil, fmt.Errorf("failed to format query: %v", err) + } + return c.ExecuteSQL(table, query) +} + +func formatQuery(queryPattern string, params []interface{}) (string, error) { + // Count the number of placeholders in queryPattern + numPlaceholders := strings.Count(queryPattern, "?") + if numPlaceholders != len(params) { + return "", fmt.Errorf("number of placeholders in queryPattern (%d) does not match number of params (%d)", numPlaceholders, len(params)) + } + + // Split the query by '?' and incrementally build the new query + parts := strings.Split(queryPattern, "?") + + var newQuery strings.Builder + for i, part := range parts[:len(parts)-1] { + newQuery.WriteString(part) + formattedParam, err := formatArg(params[i]) + if err != nil { + log.Errorf("Failed to format parameter: %v\n", err) + return "", fmt.Errorf("failed to format parameter: %v", err) + } + newQuery.WriteString(formattedParam) + } + // Add the last part of the query, which does not follow a '?' + newQuery.WriteString(parts[len(parts)-1]) + return newQuery.String(), nil +} + +func formatArg(value interface{}) (string, error) { + switch v := value.(type) { + case string, *big.Int, *big.Float: + // For pinot types - STRING, BIG_DECIMAL and BYTES - enclose in single quotes + return fmt.Sprintf("'%v'", v), nil + case []byte: + // For pinot type - BYTES - convert to Hex string and enclose in single quotes + hexString := fmt.Sprintf("%x", v) + return fmt.Sprintf("'%s'", hexString), nil + case time.Time: + // For pinot type - TIMESTAMP - convert to ISO8601 format and enclose in single quotes + return fmt.Sprintf("'%s'", v.Format("2006-01-02 15:04:05.000")), nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: + // For types - INT, LONG, FLOAT, DOUBLE and BOOLEAN use as-is + return fmt.Sprintf("%v", v), nil + default: + // Throw error for unsupported types + return "", fmt.Errorf("unsupported type: %T", v) + } +} + // OpenTrace for the connection func (c *Connection) OpenTrace() { c.trace = true diff --git a/pinot/connection_test.go b/pinot/connection_test.go index fbc2dec..7f0ee60 100644 --- a/pinot/connection_test.go +++ b/pinot/connection_test.go @@ -3,6 +3,7 @@ package pinot import ( "encoding/json" "fmt" + "math/big" "net/http" "net/http/httptest" "strings" @@ -10,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestSendingSQLWithMockServer(t *testing.T) { @@ -152,3 +154,189 @@ func TestSendingQueryWithTraceClose(t *testing.T) { assert.Nil(t, resp) assert.NotNil(t, err) } + +func TestFormatQuery(t *testing.T) { + // Test case 1: No parameters + queryPattern := "SELECT * FROM table" + expectedQuery := "SELECT * FROM table" + actualQuery, err := formatQuery(queryPattern, nil) + assert.Nil(t, err) + assert.Equal(t, expectedQuery, actualQuery) + + // Test case 2: Single parameter + queryPattern = "SELECT * FROM table WHERE id = ?" + params := []interface{}{42} + expectedQuery = "SELECT * FROM table WHERE id = 42" + actualQuery, err = formatQuery(queryPattern, params) + assert.Nil(t, err) + assert.Equal(t, expectedQuery, actualQuery) + + // Test case 3: Multiple parameters + queryPattern = "SELECT * FROM table WHERE id = ? AND name = ?" + params = []interface{}{42, "John"} + expectedQuery = "SELECT * FROM table WHERE id = 42 AND name = 'John'" + actualQuery, err = formatQuery(queryPattern, params) + assert.Nil(t, err) + assert.Equal(t, expectedQuery, actualQuery) + + // Test case 4: Invalid query pattern + queryPattern = "SELECT * FROM table WHERE id = ? AND name = ?" + params = []interface{}{42} // Missing second parameter + expectedQuery = "" // Empty query + actualQuery, err = formatQuery(queryPattern, params) + assert.NotNil(t, err) + assert.Equal(t, expectedQuery, actualQuery) +} + +func TestFormatArg(t *testing.T) { + // Test case 1: string value + value1 := "hello" + expected1 := "'hello'" + actual1, err := formatArg(value1) + assert.Nil(t, err) + assert.Equal(t, expected1, actual1) + + // Test case 2: time.Time value + value2 := time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC) + expected2 := "'2022-01-01 12:00:00.000'" + actual2, err := formatArg(value2) + assert.Nil(t, err) + assert.Equal(t, expected2, actual2) + + // Test case 3: int value + value3 := 42 + expected3 := "42" + actual3, err := formatArg(value3) + assert.Nil(t, err) + assert.Equal(t, expected3, actual3) + + // Test case 4: big.Int value + value4 := big.NewInt(1234567890) + expected4 := "'1234567890'" + actual4, err := formatArg(value4) + assert.Nil(t, err) + assert.Equal(t, expected4, actual4) + + // Test case 5: float32 value + value5 := float32(3.14) + expected5 := "3.14" + actual5, err := formatArg(value5) + assert.Nil(t, err) + assert.Equal(t, expected5, actual5) + + // Test case 6: float64 value + value6 := float64(3.14159) + expected6 := "3.14159" + actual6, err := formatArg(value6) + assert.Nil(t, err) + assert.Equal(t, expected6, actual6) + + // Test case 7: bool value + value7 := true + expected7 := "true" + actual7, err := formatArg(value7) + assert.Nil(t, err) + assert.Equal(t, expected7, actual7) + + // Test case 8: unsupported type + value8 := struct{}{} + expected8 := "unsupported type: struct {}" + _, err = formatArg(value8) + assert.NotNil(t, err) + assert.Equal(t, expected8, err.Error()) + + // Test case 9: big.Float value + value9 := big.NewFloat(3.141592653589793238) + expected9 := "'3.141592653589793'" + actual9, err := formatArg(value9) + assert.Nil(t, err) + assert.Equal(t, expected9, actual9) + + // Test case 10: byte array value + value10 := []byte{0x48, 0x65, 0x6c, 0x6c, 0x6f} + expected10 := "'48656c6c6f'" + actual10, err := formatArg(value10) + assert.Nil(t, err) + assert.Equal(t, expected10, actual10) +} + +type mockBrokerSelector struct { + mock.Mock +} + +func (m *mockBrokerSelector) init() error { return nil } +func (m *mockBrokerSelector) selectBroker(table string) (string, error) { + args := m.Called(table) + return args.Get(0).(string), args.Error(1) +} + +type mockTransport struct { + mock.Mock +} + +func (m *mockTransport) execute(brokerAddress string, query *Request) (*BrokerResponse, error) { + args := m.Called(brokerAddress, query) + return args.Get(0).(*BrokerResponse), args.Error(1) +} + +func TestExecuteSQLWithParams(t *testing.T) { + mockBrokerSelector := &mockBrokerSelector{} + mockTransport := &mockTransport{} + + // Create Connection with mock brokerSelector and transport + conn := &Connection{ + brokerSelector: mockBrokerSelector, + transport: mockTransport, + } + + // Test case 1: Successful execution + mockBrokerSelector.On("selectBroker", "baseballStats").Return("host1:8000", nil) + mockTransport.On("execute", "host1:8000", mock.Anything).Return(&BrokerResponse{}, nil) + + queryPattern := "SELECT * FROM table WHERE id = ?" + params := []interface{}{42} + expectedQuery := "SELECT * FROM table WHERE id = 42" + expectedBrokerResp := &BrokerResponse{} + mockTransport.On("execute", "host1:8000", &Request{ + queryFormat: "sql", + query: expectedQuery, + trace: false, + useMultistageEngine: false, + }).Return(expectedBrokerResp, nil) + + brokerResp, err := conn.ExecuteSQLWithParams("baseballStats", queryPattern, params) + + assert.Nil(t, err) + assert.Equal(t, expectedBrokerResp, brokerResp) + + // Test case 2: Error in selecting broker + mockBrokerSelector.On("selectBroker", "baseballStats2").Return("", fmt.Errorf("error selecting broker")) + + _, err = conn.ExecuteSQLWithParams("baseballStats2", queryPattern, params) + + assert.NotNil(t, err) + assert.EqualError(t, err, "error selecting broker") + + // Test case 3: Error in formatting query + mockBrokerSelector.On("selectBroker", "baseballStats3").Return("host2:8000", nil) + mockTransport.On("execute", "host2:8000", mock.Anything).Return(&BrokerResponse{}, fmt.Errorf("error executing query")) + + _, err = conn.ExecuteSQLWithParams("baseballStats3", queryPattern, params) + + assert.NotNil(t, err) + assert.EqualError(t, err, "error executing query") + + // Test case 4: Error in formatting query with mismatched number of parameters + queryPattern = "SELECT * FROM table WHERE id = ? AND name = ?" + params = []interface{}{42} // Missing second parameter + _, err = conn.ExecuteSQLWithParams("baseballStats", queryPattern, params) + assert.NotNil(t, err) + assert.EqualError(t, err, "failed to format query: number of placeholders in queryPattern (2) does not match number of params (1)") + + // Test case 5: Unsupported argument type + queryPattern = "SELECT * FROM table WHERE id = ?" + params = []interface{}{struct{}{}} + _, err = conn.ExecuteSQLWithParams("baseballStats", queryPattern, params) + assert.NotNil(t, err) + assert.EqualError(t, err, "failed to format query: failed to format parameter: unsupported type: struct {}") +}