-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow server specified default port to be used + test cases
- Loading branch information
1 parent
94277ec
commit a92d36e
Showing
3 changed files
with
101 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package rds | ||
|
||
type getLocalPortInput struct { | ||
// OverrideFlag is set by the user using the --port flag | ||
OverrideFlag int | ||
// DefaultFromServer is the port number specified by admins in the Terraform provider | ||
DefaultFromServer int | ||
// Fallback is the port to default to if OverrideFlag and DefaultFromServer are not set | ||
Fallback int | ||
} | ||
|
||
// getLocalPort returns the port number to use for the local port | ||
// | ||
// Common Fate allows admins to set default ports in the Terraform provider and | ||
// users to override them with the --port flag when running granted rds proxy --port <port> | ||
// | ||
// The order of priorities is: | ||
// 1. OverrideFlag | ||
// 2. DefaultFromServer | ||
// 3. Fallback | ||
// | ||
// You should set Fallback to 5432 for PostgreSQL and 3306 for MySQL | ||
func getLocalPort(input getLocalPortInput) int { | ||
if input.OverrideFlag != 0 { | ||
return input.OverrideFlag | ||
} | ||
if input.DefaultFromServer != 0 { | ||
return input.DefaultFromServer | ||
} | ||
return input.Fallback | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package rds | ||
|
||
import "testing" | ||
|
||
func Test_getLocalPort(t *testing.T) { | ||
type args struct { | ||
input getLocalPortInput | ||
} | ||
tests := []struct { | ||
name string | ||
args args | ||
want int | ||
}{ | ||
// TODO: Add test cases. | ||
{ | ||
name: "OverridePortTakesPriority", | ||
args: args{ | ||
input: getLocalPortInput{ | ||
OverrideFlag: 5000, | ||
DefaultFromServer: 8080, | ||
Fallback: 5432, | ||
}, | ||
}, | ||
want: 5000, | ||
}, | ||
{ | ||
name: "DefaultFromServerTakesPriority", | ||
args: args{ | ||
input: getLocalPortInput{ | ||
OverrideFlag: 0, | ||
DefaultFromServer: 8080, | ||
Fallback: 5432, | ||
}, | ||
}, | ||
want: 8080, | ||
}, | ||
{ | ||
name: "FallbackTakesPriority", | ||
args: args{ | ||
input: getLocalPortInput{ | ||
OverrideFlag: 0, | ||
DefaultFromServer: 0, | ||
Fallback: 5432, | ||
}, | ||
}, | ||
want: 5432, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
if got := getLocalPort(tt.args.input); got != tt.want { | ||
t.Errorf("getLocalPort() = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -570,17 +570,23 @@ func clientConnectionParameters(c *cli.Context, ensuredAccess *ensureAccessOutpu | |
// Print the connection information to the user based on the database they are connecting to | ||
// the passwords are always 'password' while the username and database will match that of the target being connected to | ||
yellow := color.New(color.FgYellow) | ||
// the port that the user connects to | ||
overridePort := c.Int("port") | ||
switch ensuredAccess.GrantOutput.RdsDatabase.Engine { | ||
case "postgres", "aurora-postgresql": | ||
port = grab.If(overridePort != 0, strconv.Itoa(overridePort), "5432") | ||
connectionString = yellow.Sprintf("postgresql://%s:[email protected]:%s/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
cliString = yellow.Sprintf(`psql "postgresql://%s:[email protected]:%s/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
port := getLocalPort(getLocalPortInput{ | ||
OverrideFlag: c.Int("port"), | ||
DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort), | ||
Fallback: 5432, | ||
}) | ||
connectionString = yellow.Sprintf("postgresql://%s:[email protected]:%d/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
cliString = yellow.Sprintf(`psql "postgresql://%s:[email protected]:%d/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
case "mysql", "aurora-mysql": | ||
port = grab.If(overridePort != 0, strconv.Itoa(overridePort), "3306") | ||
connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%s)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %s %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
port := getLocalPort(getLocalPortInput{ | ||
OverrideFlag: c.Int("port"), | ||
DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort), | ||
Fallback: 3306, | ||
}) | ||
connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%d)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %d %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) | ||
default: | ||
return "", "", "", fmt.Errorf("unsupported database engine: %s, maybe you need to update your `cf` cli", ensuredAccess.GrantOutput.RdsDatabase.Engine) | ||
} | ||
|