Skip to content

Commit

Permalink
Allow server specified default port to be used + test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
VigneshSelvaraj96 committed Oct 31, 2024
1 parent 94277ec commit a92d36e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 8 deletions.
31 changes: 31 additions & 0 deletions pkg/granted/rds/local_port.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package rds

type getLocalPortInput struct {
// OverrideFlag is set by the user using the --port flag
OverrideFlag int
// DefaultFromServer is the port number specified by admins in the Terraform provider
DefaultFromServer int
// Fallback is the port to default to if OverrideFlag and DefaultFromServer are not set
Fallback int
}

// getLocalPort returns the port number to use for the local port
//
// Common Fate allows admins to set default ports in the Terraform provider and
// users to override them with the --port flag when running granted rds proxy --port <port>
//
// The order of priorities is:
// 1. OverrideFlag
// 2. DefaultFromServer
// 3. Fallback
//
// You should set Fallback to 5432 for PostgreSQL and 3306 for MySQL
func getLocalPort(input getLocalPortInput) int {
if input.OverrideFlag != 0 {
return input.OverrideFlag
}
if input.DefaultFromServer != 0 {
return input.DefaultFromServer
}
return input.Fallback
}
56 changes: 56 additions & 0 deletions pkg/granted/rds/local_port_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package rds

import "testing"

func Test_getLocalPort(t *testing.T) {
type args struct {
input getLocalPortInput
}
tests := []struct {
name string
args args
want int
}{
// TODO: Add test cases.
{
name: "OverridePortTakesPriority",
args: args{
input: getLocalPortInput{
OverrideFlag: 5000,
DefaultFromServer: 8080,
Fallback: 5432,
},
},
want: 5000,
},
{
name: "DefaultFromServerTakesPriority",
args: args{
input: getLocalPortInput{
OverrideFlag: 0,
DefaultFromServer: 8080,
Fallback: 5432,
},
},
want: 8080,
},
{
name: "FallbackTakesPriority",
args: args{
input: getLocalPortInput{
OverrideFlag: 0,
DefaultFromServer: 0,
Fallback: 5432,
},
},
want: 5432,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getLocalPort(tt.args.input); got != tt.want {
t.Errorf("getLocalPort() = %v, want %v", got, tt.want)
}
})
}
}
22 changes: 14 additions & 8 deletions pkg/granted/rds/rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,17 +570,23 @@ func clientConnectionParameters(c *cli.Context, ensuredAccess *ensureAccessOutpu
// Print the connection information to the user based on the database they are connecting to
// the passwords are always 'password' while the username and database will match that of the target being connected to
yellow := color.New(color.FgYellow)
// the port that the user connects to
overridePort := c.Int("port")
switch ensuredAccess.GrantOutput.RdsDatabase.Engine {
case "postgres", "aurora-postgresql":
port = grab.If(overridePort != 0, strconv.Itoa(overridePort), "5432")
connectionString = yellow.Sprintf("postgresql://%s:[email protected]:%s/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
cliString = yellow.Sprintf(`psql "postgresql://%s:[email protected]:%s/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
port := getLocalPort(getLocalPortInput{
OverrideFlag: c.Int("port"),
DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort),
Fallback: 5432,
})
connectionString = yellow.Sprintf("postgresql://%s:[email protected]:%d/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
cliString = yellow.Sprintf(`psql "postgresql://%s:[email protected]:%d/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
case "mysql", "aurora-mysql":
port = grab.If(overridePort != 0, strconv.Itoa(overridePort), "3306")
connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%s)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %s %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
port := getLocalPort(getLocalPortInput{
OverrideFlag: c.Int("port"),
DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort),
Fallback: 3306,
})
connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%d)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %d %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
default:
return "", "", "", fmt.Errorf("unsupported database engine: %s, maybe you need to update your `cf` cli", ensuredAccess.GrantOutput.RdsDatabase.Engine)
}
Expand Down

0 comments on commit a92d36e

Please sign in to comment.