From a92d36ee388ac7bd9b1f1e0ad3afcf384d82a836 Mon Sep 17 00:00:00 2001 From: Vignesh Date: Thu, 31 Oct 2024 12:30:39 -0700 Subject: [PATCH] Allow server specified default port to be used + test cases --- pkg/granted/rds/local_port.go | 31 +++++++++++++++++ pkg/granted/rds/local_port_test.go | 56 ++++++++++++++++++++++++++++++ pkg/granted/rds/rds.go | 22 +++++++----- 3 files changed, 101 insertions(+), 8 deletions(-) create mode 100644 pkg/granted/rds/local_port.go create mode 100644 pkg/granted/rds/local_port_test.go diff --git a/pkg/granted/rds/local_port.go b/pkg/granted/rds/local_port.go new file mode 100644 index 00000000..f985d40d --- /dev/null +++ b/pkg/granted/rds/local_port.go @@ -0,0 +1,31 @@ +package rds + +type getLocalPortInput struct { + // OverrideFlag is set by the user using the --port flag + OverrideFlag int + // DefaultFromServer is the port number specified by admins in the Terraform provider + DefaultFromServer int + // Fallback is the port to default to if OverrideFlag and DefaultFromServer are not set + Fallback int +} + +// getLocalPort returns the port number to use for the local port +// +// Common Fate allows admins to set default ports in the Terraform provider and +// users to override them with the --port flag when running granted rds proxy --port +// +// The order of priorities is: +// 1. OverrideFlag +// 2. DefaultFromServer +// 3. Fallback +// +// You should set Fallback to 5432 for PostgreSQL and 3306 for MySQL +func getLocalPort(input getLocalPortInput) int { + if input.OverrideFlag != 0 { + return input.OverrideFlag + } + if input.DefaultFromServer != 0 { + return input.DefaultFromServer + } + return input.Fallback +} diff --git a/pkg/granted/rds/local_port_test.go b/pkg/granted/rds/local_port_test.go new file mode 100644 index 00000000..85136073 --- /dev/null +++ b/pkg/granted/rds/local_port_test.go @@ -0,0 +1,56 @@ +package rds + +import "testing" + +func Test_getLocalPort(t *testing.T) { + type args struct { + input getLocalPortInput + } + tests := []struct { + name string + args args + want int + }{ + // TODO: Add test cases. + { + name: "OverridePortTakesPriority", + args: args{ + input: getLocalPortInput{ + OverrideFlag: 5000, + DefaultFromServer: 8080, + Fallback: 5432, + }, + }, + want: 5000, + }, + { + name: "DefaultFromServerTakesPriority", + args: args{ + input: getLocalPortInput{ + OverrideFlag: 0, + DefaultFromServer: 8080, + Fallback: 5432, + }, + }, + want: 8080, + }, + { + name: "FallbackTakesPriority", + args: args{ + input: getLocalPortInput{ + OverrideFlag: 0, + DefaultFromServer: 0, + Fallback: 5432, + }, + }, + want: 5432, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getLocalPort(tt.args.input); got != tt.want { + t.Errorf("getLocalPort() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/granted/rds/rds.go b/pkg/granted/rds/rds.go index dcd9263f..193b6eee 100644 --- a/pkg/granted/rds/rds.go +++ b/pkg/granted/rds/rds.go @@ -570,17 +570,23 @@ func clientConnectionParameters(c *cli.Context, ensuredAccess *ensureAccessOutpu // Print the connection information to the user based on the database they are connecting to // the passwords are always 'password' while the username and database will match that of the target being connected to yellow := color.New(color.FgYellow) - // the port that the user connects to - overridePort := c.Int("port") switch ensuredAccess.GrantOutput.RdsDatabase.Engine { case "postgres", "aurora-postgresql": - port = grab.If(overridePort != 0, strconv.Itoa(overridePort), "5432") - connectionString = yellow.Sprintf("postgresql://%s:password@127.0.0.1:%s/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) - cliString = yellow.Sprintf(`psql "postgresql://%s:password@127.0.0.1:%s/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) + port := getLocalPort(getLocalPortInput{ + OverrideFlag: c.Int("port"), + DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort), + Fallback: 5432, + }) + connectionString = yellow.Sprintf("postgresql://%s:password@127.0.0.1:%d/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) + cliString = yellow.Sprintf(`psql "postgresql://%s:password@127.0.0.1:%d/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) case "mysql", "aurora-mysql": - port = grab.If(overridePort != 0, strconv.Itoa(overridePort), "3306") - connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%s)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) - cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %s %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) + port := getLocalPort(getLocalPortInput{ + OverrideFlag: c.Int("port"), + DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort), + Fallback: 3306, + }) + connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%d)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) + cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %d %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database) default: return "", "", "", fmt.Errorf("unsupported database engine: %s, maybe you need to update your `cf` cli", ensuredAccess.GrantOutput.RdsDatabase.Engine) }