Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add correct pagination for filters #335

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions core/graph/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/graph/model/models_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/graph/schema.graphqls
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Column {
type RowsResult {
Columns: [Column!]!
Rows: [[String!]!]!
TotalCount: Int!
DisableUpdate: Boolean!
}

Expand Down
1 change: 1 addition & 0 deletions core/graph/schema.resolvers.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/src/engine/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Column struct {
type GetRowsResult struct {
Columns []Column
Rows [][]string
TotalCount int
DisableUpdate bool
}

Expand Down
61 changes: 61 additions & 0 deletions core/src/plugins/clickhouse/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package clickhouse

import (
"fmt"
"strings"

"github.com/clidey/whodb/core/src/common"
"github.com/clidey/whodb/core/src/engine"
"github.com/clidey/whodb/core/src/llm"
)

func (p *ClickHousePlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) {
db, err := DB(config)
if err != nil {
return nil, err
}

tableFields, err := getAllTableSchema(db, schema)
if err != nil {
return nil, err
}

tableDetails := strings.Builder{}
for tableName, fields := range tableFields {
tableDetails.WriteString(fmt.Sprintf("table: %v\n", tableName))
for _, field := range fields {
tableDetails.WriteString(fmt.Sprintf("- %v (%v)\n", field.Key, field.Value))
}
}

context := tableDetails.String()

completeQuery := fmt.Sprintf(common.RawSQLQueryPrompt, "Postgres", schema, context, previousConversation, query, "Postgres")

response, err := llm.Instance(config).Complete(completeQuery, llm.LLMModel(model), nil)
if err != nil {
return nil, err
}

chats := common.ExtractCodeFromResponse(*response)
chatMessages := []*engine.ChatMessage{}
for _, chat := range chats {
var result *engine.GetRowsResult
chatType := "message"
if chat.Type == "sql" {
rowResult, err := p.RawExecute(config, chat.Text)
if err != nil {
return nil, err
}
chatType = "sql"
result = rowResult
}
chatMessages = append(chatMessages, &engine.ChatMessage{
Type: chatType,
Result: result,
Text: chat.Text,
})
}

return chatMessages, nil
}
35 changes: 29 additions & 6 deletions core/src/plugins/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,35 @@
return storageUnits, nil
}

func getAllTableSchema(conn *sql.DB, schema string) (map[string][]engine.Record, error) {
query := fmt.Sprintf(`
SELECT
table,
name,
type
FROM system.columns
WHERE database = '%s'
ORDER BY table, position
`, schema)

Check failure

Code scanning / CodeQL

Database query built from user-controlled sources High

This query depends on a
user-provided value
.
This query depends on a
user-provided value
.
This query depends on a user-provided value.
This query depends on a user-provided value.

Copilot Autofix AI 5 days ago

To fix the problem, we should use parameterized queries or prepared statements to safely embed user input into SQL queries. This approach prevents SQL injection by ensuring that user input is treated as data rather than executable code.

In the provided code, we need to replace the use of fmt.Sprintf with parameterized queries using the QueryContext method. This involves modifying the getAllTableSchema and getTableSchema functions to use query parameters instead of directly embedding the schema and tableName values into the query string.

Suggested changeset 1
core/src/plugins/clickhouse/clickhouse.go

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/core/src/plugins/clickhouse/clickhouse.go b/core/src/plugins/clickhouse/clickhouse.go
--- a/core/src/plugins/clickhouse/clickhouse.go
+++ b/core/src/plugins/clickhouse/clickhouse.go
@@ -105,3 +105,3 @@
 func getAllTableSchema(conn *sql.DB, schema string) (map[string][]engine.Record, error) {
-	query := fmt.Sprintf(`
+	query := `
 		SELECT 
@@ -111,7 +111,7 @@
 		FROM system.columns
-		WHERE database = '%s'
+		WHERE database = ?
 		ORDER BY table, position
-	`, schema)
+	`
 
-	rows, err := conn.QueryContext(context.Background(), query)
+	rows, err := conn.QueryContext(context.Background(), query, schema)
 	if err != nil {
@@ -134,3 +134,3 @@
 func getTableSchema(conn *sql.DB, schema string, tableName string) ([]engine.Record, error) {
-	query := fmt.Sprintf(`
+	query := `
 		SELECT 
@@ -139,7 +139,7 @@
 		FROM system.columns
-		WHERE database = '%s' AND table = '%s'
+		WHERE database = ? AND table = ?
 		ORDER BY position
-	`, schema, tableName)
+	`
 
-	rows, err := conn.QueryContext(context.Background(), query)
+	rows, err := conn.QueryContext(context.Background(), query, schema, tableName)
 	if err != nil {
EOF
@@ -105,3 +105,3 @@
func getAllTableSchema(conn *sql.DB, schema string) (map[string][]engine.Record, error) {
query := fmt.Sprintf(`
query := `
SELECT
@@ -111,7 +111,7 @@
FROM system.columns
WHERE database = '%s'
WHERE database = ?
ORDER BY table, position
`, schema)
`

rows, err := conn.QueryContext(context.Background(), query)
rows, err := conn.QueryContext(context.Background(), query, schema)
if err != nil {
@@ -134,3 +134,3 @@
func getTableSchema(conn *sql.DB, schema string, tableName string) ([]engine.Record, error) {
query := fmt.Sprintf(`
query := `
SELECT
@@ -139,7 +139,7 @@
FROM system.columns
WHERE database = '%s' AND table = '%s'
WHERE database = ? AND table = ?
ORDER BY position
`, schema, tableName)
`

rows, err := conn.QueryContext(context.Background(), query)
rows, err := conn.QueryContext(context.Background(), query, schema, tableName)
if err != nil {
Copilot is powered by AI and may make mistakes. Always verify output.
Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options
rows, err := conn.QueryContext(context.Background(), query)
if err != nil {
return nil, err
}
defer rows.Close()

tableColumnsMap := make(map[string][]engine.Record)
for rows.Next() {
var tableName, columnName, dataType string
if err := rows.Scan(&tableName, &columnName, &dataType); err != nil {
return nil, err
}
tableColumnsMap[tableName] = append(tableColumnsMap[tableName], engine.Record{Key: columnName, Value: dataType})
}

return tableColumnsMap, nil
}

func getTableSchema(conn *sql.DB, schema string, tableName string) ([]engine.Record, error) {
query := fmt.Sprintf(`
SELECT
Expand Down Expand Up @@ -130,12 +159,6 @@
return result, nil
}

func (p *ClickHousePlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) {
// Implement chat functionality similar to MySQL implementation
// You may need to adapt this based on ClickHouse specifics
return nil, fmt.Errorf("chat functionality not implemented for ClickHouse")
}

func NewClickHousePlugin() *engine.Plugin {
return &engine.Plugin{
Type: engine.DatabaseType_ClickHouse,
Expand Down
14 changes: 10 additions & 4 deletions core/src/plugins/clickhouse/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@ import (
"context"
"database/sql"
"fmt"

"github.com/clidey/whodb/core/src/engine"
)

func (p *ClickHousePlugin) GetRows(config *engine.PluginConfig, schema string, storageUnit string, where string, pageSize int, pageOffset int) (*engine.GetRowsResult, error) {
query := fmt.Sprintf("SELECT * FROM %s.%s", schema, storageUnit)
baseQuery := fmt.Sprintf("FROM %s.%s", schema, storageUnit)
if where != "" {
query += " WHERE " + where
baseQuery += " WHERE " + where
}
query += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, pageOffset)
query := fmt.Sprintf("SELECT * %s LIMIT %d OFFSET %d", baseQuery, pageSize, pageOffset)

return p.executeQuery(config, query)
result, err := p.executeQuery(config, query)
if err != nil {
return nil, err
}

return result, nil
}

func (p *ClickHousePlugin) RawExecute(config *engine.PluginConfig, query string) (*engine.GetRowsResult, error) {
Expand Down
12 changes: 8 additions & 4 deletions core/src/plugins/elasticsearch/elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ func (p *ElasticSearchPlugin) GetRows(config *engine.PluginConfig, database, col
query[key] = value
}

var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(query); err != nil {
var searchBuf bytes.Buffer
if err := json.NewEncoder(&searchBuf).Encode(query); err != nil {
return nil, err
}

res, err := client.Search(
client.Search.WithContext(context.Background()),
client.Search.WithIndex(collection),
client.Search.WithBody(&buf),
client.Search.WithBody(&searchBuf),
client.Search.WithTrackTotalHits(true),
)
if err != nil {
Expand All @@ -122,7 +122,10 @@ func (p *ElasticSearchPlugin) GetRows(config *engine.PluginConfig, database, col
return nil, err
}

hits := searchResult["hits"].(map[string]interface{})["hits"].([]interface{})
hitsInfo := searchResult["hits"].(map[string]interface{})
totalHits := int(hitsInfo["total"].(map[string]interface{})["value"].(float64))

hits := hitsInfo["hits"].([]interface{})
result := &engine.GetRowsResult{
Columns: []engine.Column{
{Name: "document", Type: "Document"},
Expand All @@ -142,6 +145,7 @@ func (p *ElasticSearchPlugin) GetRows(config *engine.PluginConfig, database, col
result.Rows = append(result.Rows, []string{string(jsonBytes)})
}

result.TotalCount = totalHits
return result, nil
}

Expand Down
6 changes: 6 additions & 0 deletions core/src/plugins/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ func (p *MongoDBPlugin) GetRows(config *engine.PluginConfig, database, collectio
}
}

totalCount, err := coll.CountDocuments(context.TODO(), bsonFilter)
if err != nil {
return nil, err
}

findOptions := options.Find()
findOptions.SetLimit(int64(pageSize))
findOptions.SetSkip(int64(pageOffset))
Expand Down Expand Up @@ -124,6 +129,7 @@ func (p *MongoDBPlugin) GetRows(config *engine.PluginConfig, database, collectio
})
}

result.TotalCount = int(totalCount)
return result, nil
}

Expand Down
28 changes: 24 additions & 4 deletions core/src/plugins/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,32 @@
}

func (p *MySQLPlugin) GetRows(config *engine.PluginConfig, schema string, storageUnit string, where string, pageSize int, pageOffset int) (*engine.GetRowsResult, error) {
query := fmt.Sprintf("SELECT * FROM `%v`.`%s`", schema, storageUnit)
db, err := DB(config)
if err != nil {
return nil, err
}

baseQuery := fmt.Sprintf("FROM `%v`.`%s`", schema, storageUnit)
if len(where) > 0 {
query = fmt.Sprintf("%v WHERE %v", query, where)
baseQuery = fmt.Sprintf("%v WHERE %v", baseQuery, where)
}

countQuery := fmt.Sprintf("SELECT COUNT(*) %v", baseQuery)
var totalCount int
err = db.Raw(countQuery).Scan(&totalCount).Error

Check failure

Code scanning / CodeQL

Database query built from user-controlled sources High

This query depends on a
user-provided value
.
This query depends on a
user-provided value
.
This query depends on a
user-provided value
.
This query depends on a
user-provided value
.
This query depends on a
user-provided value
.
This query depends on a
user-provided value
.

Copilot Autofix AI 5 days ago

To fix the problem, we need to replace the unsafe construction of SQL queries using fmt.Sprintf with parameterized queries. This involves using placeholders in the SQL query and passing the user-provided values as separate arguments to the query execution function. This approach ensures that the user-provided values are properly escaped and prevents SQL injection attacks.

In the core/src/plugins/mysql/mysql.go file, we need to:

  1. Replace the fmt.Sprintf calls with parameterized queries.
  2. Pass the user-provided values as arguments to the db.Raw function.
Suggested changeset 1
core/src/plugins/mysql/mysql.go

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/core/src/plugins/mysql/mysql.go b/core/src/plugins/mysql/mysql.go
--- a/core/src/plugins/mysql/mysql.go
+++ b/core/src/plugins/mysql/mysql.go
@@ -119,10 +119,10 @@
 
-	query := fmt.Sprintf(`
+	query := `
 		SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
 		FROM INFORMATION_SCHEMA.COLUMNS
-		WHERE TABLE_SCHEMA = '%v'
+		WHERE TABLE_SCHEMA = ?
 		ORDER BY TABLE_NAME, ORDINAL_POSITION
-	`, schema)
+	`
 
-	if err := db.Raw(query).Scan(&result).Error; err != nil {
+	if err := db.Raw(query, schema).Scan(&result).Error; err != nil {
 		return nil, err
@@ -144,10 +144,11 @@
 
-	baseQuery := fmt.Sprintf("FROM `%v`.`%s`", schema, storageUnit)
+	baseQuery := "FROM `?`.`?`"
+	args := []interface{}{schema, storageUnit}
 	if len(where) > 0 {
-		baseQuery = fmt.Sprintf("%v WHERE %v", baseQuery, where)
+		baseQuery += " WHERE " + where
 	}
 
-	countQuery := fmt.Sprintf("SELECT COUNT(*) %v", baseQuery)
+	countQuery := "SELECT COUNT(*) " + baseQuery
 	var totalCount int
-	err = db.Raw(countQuery).Scan(&totalCount).Error
+	err = db.Raw(countQuery, args...).Scan(&totalCount).Error
 	if err != nil {
@@ -156,5 +157,6 @@
 
-	query := fmt.Sprintf("SELECT * %v LIMIT ? OFFSET ?", baseQuery)
+	query := "SELECT * " + baseQuery + " LIMIT ? OFFSET ?"
+	args = append(args, pageSize, pageOffset)
 
-	result, err := p.executeRawSQL(config, query, pageSize, pageOffset)
+	result, err := p.executeRawSQL(config, query, args...)
 	if err != nil {
EOF
@@ -119,10 +119,10 @@

query := fmt.Sprintf(`
query := `
SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = '%v'
WHERE TABLE_SCHEMA = ?
ORDER BY TABLE_NAME, ORDINAL_POSITION
`, schema)
`

if err := db.Raw(query).Scan(&result).Error; err != nil {
if err := db.Raw(query, schema).Scan(&result).Error; err != nil {
return nil, err
@@ -144,10 +144,11 @@

baseQuery := fmt.Sprintf("FROM `%v`.`%s`", schema, storageUnit)
baseQuery := "FROM `?`.`?`"
args := []interface{}{schema, storageUnit}
if len(where) > 0 {
baseQuery = fmt.Sprintf("%v WHERE %v", baseQuery, where)
baseQuery += " WHERE " + where
}

countQuery := fmt.Sprintf("SELECT COUNT(*) %v", baseQuery)
countQuery := "SELECT COUNT(*) " + baseQuery
var totalCount int
err = db.Raw(countQuery).Scan(&totalCount).Error
err = db.Raw(countQuery, args...).Scan(&totalCount).Error
if err != nil {
@@ -156,5 +157,6 @@

query := fmt.Sprintf("SELECT * %v LIMIT ? OFFSET ?", baseQuery)
query := "SELECT * " + baseQuery + " LIMIT ? OFFSET ?"
args = append(args, pageSize, pageOffset)

result, err := p.executeRawSQL(config, query, pageSize, pageOffset)
result, err := p.executeRawSQL(config, query, args...)
if err != nil {
Copilot is powered by AI and may make mistakes. Always verify output.
Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options
if err != nil {
return nil, err
}

query := fmt.Sprintf("SELECT * %v LIMIT ? OFFSET ?", baseQuery)

result, err := p.executeRawSQL(config, query, pageSize, pageOffset)
if err != nil {
return nil, err
}
query = fmt.Sprintf("%v LIMIT ? OFFSET ?", query)
return p.executeRawSQL(config, query, pageSize, pageOffset)

result.TotalCount = totalCount
return result, nil
}

func (p *MySQLPlugin) executeRawSQL(config *engine.PluginConfig, query string, params ...interface{}) (*engine.GetRowsResult, error) {
Expand Down
Loading