Skip to content

Commit

Permalink
Merge pull request #54 from max-ieremenko/feature/53
Browse files Browse the repository at this point in the history
#53: postgres: run scripts as a non-superuser
  • Loading branch information
max-ieremenko authored Jun 13, 2024
2 parents 7e1ca35 + 03039b6 commit 89000db
Show file tree
Hide file tree
Showing 15 changed files with 135 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Build/scripts/Start-Pgsql.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function Start-Pgsql {

$builder = New-Object -TypeName Npgsql.NpgsqlConnectionStringBuilder
$builder["Database"] = "sqldatabasetest"
$builder["Username"] = "postgres"
$builder["Username"] = "adminuser"
$builder["Password"] = "qwerty"
$builder["Timeout"] = 5

Expand Down
2 changes: 1 addition & 1 deletion Examples/ExecuteScriptsFolder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ Predefined variables
Opening a connection
========================

Before starting any step SqlDatabase checks if a database, provided in the connection string, exists. If database does not exists the connection will be targeted to `master` for MSSQL and `postgres` for PostgreSQL.
If the database specified in the connection string does not exist, execution will be terminated with the appropriate error.

MSSQL Server script example
=============================
Expand Down
2 changes: 1 addition & 1 deletion Examples/MigrationStepsFolder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Predefined variables
Opening a connection
========================
Before starting any step SqlDatabase checks if a database, provided in the connection string, exists. If database does not exists the connection will be targeted to `master` for MSSQL and `postgres` for PostgreSQL.
If the database specified in the connection string does not exist, execution will be terminated with the appropriate error.
Migration MSSQL Server .sql step example
=============================
Expand Down
23 changes: 21 additions & 2 deletions Sources/Docker/pgsql.create-database.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
CREATE DATABASE sqldatabasetest;
CREATE ROLE adminuser WITH
LOGIN
NOSUPERUSER
INHERIT
CREATEDB
CREATEROLE
NOREPLICATION
PASSWORD 'qwerty';

SET ROLE adminuser;

CREATE DATABASE sqldatabasetest;

\connect sqldatabasetest;

SET ROLE adminuser;

CREATE EXTENSION citext;

CREATE TABLE public.version
Expand All @@ -10,6 +23,12 @@ CREATE TABLE public.version
,version varchar(20) NOT NULL
);

CREATE TABLE public.version2
(
module_name public.citext NOT NULL
,version varchar(20) NOT NULL
);

ALTER TABLE public.version
ADD CONSTRAINT pk_version PRIMARY KEY (module_name);

Expand All @@ -24,4 +43,4 @@ CREATE TYPE public.inventory_item AS (
name text,
supplier_id integer,
price numeric
);
);
2 changes: 1 addition & 1 deletion Sources/SqlDatabase.Adapter.PgSql.Test/app.config
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
<configuration>
<connectionStrings>
<add name="pgsql"
connectionString="Host=localhost;Username=postgres;Password=qwerty;Database=sqldatabasetest;" />
connectionString="Host=localhost;Username=adminuser;Password=qwerty;Database=sqldatabasetest;" />
</connectionStrings>
</configuration>
6 changes: 3 additions & 3 deletions Sources/SqlDatabase.Adapter.PgSql/PgSqlDatabaseAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ public PgSqlDatabaseAdapter(
Pooling = false
};

DatabaseName = builder.Database;
DatabaseName = builder.Database!;
_connectionString = builder.ToString();

builder.Database = null;
builder.Database = "postgres"; // The master will always be set to postgres database
_connectionStringMaster = builder.ToString();

_onConnectionNotice = OnConnectionNotice;
Expand Down Expand Up @@ -72,4 +72,4 @@ private void OnConnectionNotice(object sender, NpgsqlNoticeEventArgs e)
{
_log.Info($"{e.Notice.Severity}: {e.Notice.MessageText}");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void BeforeEachTest()

_database = new Mock<IDatabase>(MockBehavior.Strict);
_database.SetupGet(d => d.Adapter).Returns(adapter.Object);
_database.Setup(d => d.GetServerVersion()).Returns("sql server 1.0");
_database.Setup(d => d.GetServerVersion(true)).Returns("sql server 1.0");

_scriptSequence = new Mock<ICreateScriptSequence>(MockBehavior.Strict);

Expand Down Expand Up @@ -80,8 +80,8 @@ public void ExecuteSequence()
.Setup(f => f.InitializeEnvironment(_log.Object, sequence));

_database
.Setup(d => d.Execute(step1.Object))
.Callback(() => _database.Setup(d => d.Execute(step2.Object)));
.Setup(d => d.ExecuteWithDatabaseCheck(step1.Object))
.Callback(() => _database.Setup(d => d.ExecuteWithDatabaseCheck(step2.Object)));

_scriptSequence
.Setup(s => s.BuildSequence())
Expand Down Expand Up @@ -109,7 +109,7 @@ public void StopExecutionOnError()
.Setup(f => f.InitializeEnvironment(_log.Object, sequence));

_database
.Setup(d => d.Execute(step1.Object))
.Setup(d => d.ExecuteWithDatabaseCheck(step1.Object))
.Throws<InvalidOperationException>();

_scriptSequence
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void BeforeEachTest()

_database = new Mock<IDatabase>(MockBehavior.Strict);
_database.SetupGet(d => d.Adapter).Returns(adapter.Object);
_database.Setup(d => d.GetServerVersion()).Returns("sql server 1.0");
_database.Setup(d => d.GetServerVersion(false)).Returns("sql server 1.0");

_scriptSequence = new Mock<ICreateScriptSequence>(MockBehavior.Strict);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void BeforeEachTest()

_database = new Mock<IDatabase>(MockBehavior.Strict);
_database.SetupGet(d => d.Adapter).Returns(adapter.Object);
_database.Setup(d => d.GetServerVersion()).Returns("sql server 1.0");
_database.Setup(d => d.GetServerVersion(false)).Returns("sql server 1.0");

_scriptSequence = new Mock<ICreateScriptSequence>(MockBehavior.Strict);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void BeforeEachTest()

_database = new Mock<IDatabase>(MockBehavior.Strict);
_database.SetupGet(d => d.Adapter).Returns(adapter.Object);
_database.Setup(d => d.GetServerVersion()).Returns("sql server 1.0");
_database.Setup(d => d.GetServerVersion(false)).Returns("sql server 1.0");

_scriptSequence = new Mock<IUpgradeScriptSequence>(MockBehavior.Strict);

Expand Down
63 changes: 58 additions & 5 deletions Sources/SqlDatabase.Test/Scripts/DatabaseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ public void GetCurrentVersionModuleNameInvalidVersion()
}

[Test]
public void GetServerVersion()
[TestCase(true)]
[TestCase(false)]
public void GetServerVersion(bool useMasterDatabase)
{
_adapter
.Setup(a => a.GetServerVersionSelectScript())
.Returns("select server version");
_adapter
.Setup(a => a.CreateConnection(true))
.Setup(a => a.CreateConnection(useMasterDatabase))
.Returns(_connection.Object);

_command
Expand All @@ -208,7 +210,7 @@ public void GetServerVersion()
})
.Returns("server version");

var actual = _sut.GetServerVersion();
var actual = _sut.GetServerVersion(useMasterDatabase);

actual.ShouldBe("server version");
}
Expand Down Expand Up @@ -315,6 +317,57 @@ public void Execute(TransactionMode transaction)
{
_sut.Transaction = transaction;

_command
.SetupProperty(c => c.CommandTimeout, 30);
_command
.SetupProperty(c => c.Transaction);

_connection
.Setup(c => c.BeginTransaction(IsolationLevel.ReadCommitted))
.Returns(_transaction.Object);

_adapter
.SetupGet(a => a.DatabaseName)
.Returns("database-name");
_adapter
.Setup(a => a.CreateConnection(false))
.Returns(_connection.Object);

var script = new Mock<IScript>(MockBehavior.Strict);
script
.Setup(s => s.Execute(_command.Object, It.IsNotNull<IVariables>(), It.IsNotNull<ILogger>()))
.Callback<IDbCommand, IVariables, ILogger>((cmd, variables, s) =>
{
cmd.CommandTimeout.ShouldBe(0);
if (transaction == TransactionMode.PerStep)
{
cmd.Transaction.ShouldBe(_transaction.Object);
}
else
{
cmd.Transaction.ShouldBeNull();
}
variables.GetValue("DatabaseName").ShouldBe("database-name");
variables.GetValue("CurrentVersion").ShouldBeNullOrEmpty();
variables.GetValue("TargetVersion").ShouldBeNullOrEmpty();
variables.GetValue("ModuleName").ShouldBeNullOrEmpty();
});

_sut.Execute(script.Object);

script.VerifyAll();
_command.VerifyAll();
}

[Test]
[TestCase(TransactionMode.None)]
[TestCase(TransactionMode.PerStep)]
public void ExecuteWithDatabaseCheck(TransactionMode transaction)
{
_sut.Transaction = transaction;

_command
.SetupProperty(c => c.CommandTimeout, 30);
_command
Expand Down Expand Up @@ -364,7 +417,7 @@ public void Execute(TransactionMode transaction)
variables.GetValue("ModuleName").ShouldBeNullOrEmpty();
});

_sut.Execute(script.Object);
_sut.ExecuteWithDatabaseCheck(script.Object);

script.VerifyAll();
_command.VerifyAll();
Expand Down Expand Up @@ -402,7 +455,7 @@ public void ExecuteDatabaseNotFound()
cmd.Transaction.ShouldBeNull();
});

_sut.Execute(script.Object);
_sut.ExecuteWithDatabaseCheck(script.Object);

script.VerifyAll();
_command.VerifyAll();
Expand Down
4 changes: 3 additions & 1 deletion Sources/SqlDatabase/Commands/DatabaseCommandBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ protected DatabaseCommandBase(IDatabase database, ILogger log)
public void Execute()
{
Greet(Database.Adapter.GetUserFriendlyConnectionString());
Log.Info(Database.GetServerVersion());
Log.Info(GetServerVersion());

ExecuteCore();
}

protected abstract void Greet(string databaseLocation);

protected abstract void ExecuteCore();

protected virtual string GetServerVersion() => Database.GetServerVersion(false);
}
4 changes: 3 additions & 1 deletion Sources/SqlDatabase/Commands/DatabaseCreateCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ protected override void Greet(string databaseLocation)
Log.Info($"Create {databaseLocation}");
}

protected override string GetServerVersion() => Database.GetServerVersion(true);

protected override void ExecuteCore()
{
var sequences = ScriptSequence.BuildSequence();
Expand All @@ -44,7 +46,7 @@ protected override void ExecuteCore()

using (Log.Indent())
{
Database.Execute(script);
Database.ExecuteWithDatabaseCheck(script);
}

Log.Info($"done in {timer.Elapsed}");
Expand Down
51 changes: 33 additions & 18 deletions Sources/SqlDatabase/Scripts/Database.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ public Version GetCurrentVersion(string? moduleName)
}
}

public string GetServerVersion()
public string GetServerVersion(bool useMasterDatabase)
{
using (var connection = Adapter.CreateConnection(true))
using (var connection = Adapter.CreateConnection(useMasterDatabase))
using (var command = connection.CreateCommand())
{
command.CommandText = Adapter.GetServerVersionSelectScript();
Expand Down Expand Up @@ -78,7 +78,21 @@ public void Execute(IScript script)
}
else
{
InvokeExecute(script);
InvokeExecute(script, false);
}
}

public void ExecuteWithDatabaseCheck(IScript script)
{
Variables.DatabaseName = Adapter.DatabaseName;

if (WhatIf)
{
ExecuteWhatIf(script);
}
else
{
InvokeExecute(script, !DatabaseExists());
}
}

Expand Down Expand Up @@ -172,22 +186,8 @@ private void InvokeExecuteUpgrade(IScript script, Version targetVersion)
}
}

private void InvokeExecute(IScript script)
private void InvokeExecute(IScript script, bool useMaster)
{
bool useMaster;

using (var connection = Adapter.CreateConnection(true))
using (var command = connection.CreateCommand())
{
command.CommandTimeout = 0;
connection.Open();

command.CommandText = Adapter.GetDatabaseExistsScript(Variables.DatabaseName!);
var value = command.ExecuteScalar();

useMaster = value == null || Convert.IsDBNull(value);
}

using (var connection = Adapter.CreateConnection(useMaster))
{
connection.Open();
Expand All @@ -206,6 +206,21 @@ private void InvokeExecute(IScript script)
}
}

private bool DatabaseExists()
{
using (var connection = Adapter.CreateConnection(true))
using (var command = connection.CreateCommand())
{
command.CommandTimeout = 0;
connection.Open();

command.CommandText = Adapter.GetDatabaseExistsScript(Variables.DatabaseName!);
var value = command.ExecuteScalar();

return value != null && !Convert.IsDBNull(value);
}
}

private void ExecuteWhatIf(IScript script)
{
Log.Info("what-if mode");
Expand Down
4 changes: 3 additions & 1 deletion Sources/SqlDatabase/Scripts/IDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ internal interface IDatabase
{
IDatabaseAdapter Adapter { get; }

string GetServerVersion();
string GetServerVersion(bool useMasterDatabase);

Version GetCurrentVersion(string? moduleName);

void Execute(IScript script, string moduleName, Version currentVersion, Version targetVersion);

void Execute(IScript script);

void ExecuteWithDatabaseCheck(IScript script);

IEnumerable<IDataReader> ExecuteReader(IScript script);
}

0 comments on commit 89000db

Please sign in to comment.