diff --git a/README.md b/README.md index 2631944c..5e55b712 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Use `sqlcmd` to create SQL Server and Azure SQL Edge instances using a local con To create a local SQL Server instance with the AdventureWorksLT database restored, query it, and connect to it using Azure Data Studio, run: ``` -sqlcmd create mssql --accept-eula --using https://aka.ms/AdventureWorksLT.bak +sqlcmd create mssql --accept-eula --use https://aka.ms/AdventureWorksLT.bak sqlcmd query "SELECT DB_NAME()" sqlcmd open ads ``` diff --git a/cmd/modern/main.go b/cmd/modern/main.go index 4b358b48..b0dbb791 100644 --- a/cmd/modern/main.go +++ b/cmd/modern/main.go @@ -20,6 +20,7 @@ import ( "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/output/verbosity" "github.com/microsoft/go-sqlcmd/internal/pal" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" "github.com/spf13/cobra" "path" @@ -95,9 +96,7 @@ func initializeEnvVars() { os.Setenv("SQLCMDPASSWORD", password) } } - } - } // isFirstArgModernCliSubCommand is TEMPORARY code, to be removed when @@ -132,6 +131,10 @@ func initializeCallback() { HintHandler: displayHints, LineBreak: sqlcmd.SqlcmdEol, }) + mssqlcontainer.Initialize(mssqlcontainer.InitializeOptions{ + ErrorHandler: checkErr, + TraceHandler: outputter.Tracef, + }) config.SetFileName(rootCmd.configFilename) config.Load() } diff --git a/cmd/modern/root.go b/cmd/modern/root.go index b22fb6bc..f0101c31 100644 --- a/cmd/modern/root.go +++ b/cmd/modern/root.go @@ -27,7 +27,7 @@ type Root struct { // It also provides usage examples for sqlcmd. func (c *Root) DefineCommand(...cmdparser.CommandOptions) { // Example usage steps - steps := []string{"sqlcmd create mssql --accept-eula --using https://aka.ms/AdventureWorksLT.bak"} + steps := []string{"sqlcmd create mssql --accept-eula --use https://aka.ms/AdventureWorksLT.bak"} if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { steps = append(steps, "sqlcmd open ads") @@ -69,6 +69,7 @@ func (c *Root) SubCommands() []cmdparser.Command { cmdparser.New[*root.Query](dependencies), cmdparser.New[*root.Start](dependencies), cmdparser.New[*root.Stop](dependencies), + cmdparser.New[*root.Use](dependencies), cmdparser.New[*root.Uninstall](dependencies), } diff --git a/cmd/modern/root/config/connection-strings.go b/cmd/modern/root/config/connection-strings.go index c9c72378..63b85ed4 100644 --- a/cmd/modern/root/config/connection-strings.go +++ b/cmd/modern/root/config/connection-strings.go @@ -71,7 +71,7 @@ func (c *ConnectionStrings) run() { if endpoint.AssetDetails != nil && endpoint.AssetDetails.ContainerDetails != nil { controller := container.NewController() if controller.ContainerRunning(endpoint.AssetDetails.ContainerDetails.Id) { - s := sql.New(sql.SqlOptions{}) + s := sql.NewSql(sql.SqlOptions{}) s.Connect(endpoint, user, sql.ConnectOptions{Interactive: false}) c.database = s.ScalarString("PRINT DB_NAME()") } else { diff --git a/cmd/modern/root/install/edge_test.go b/cmd/modern/root/install/edge_test.go index c01d7dc0..18c04f2f 100644 --- a/cmd/modern/root/install/edge_test.go +++ b/cmd/modern/root/install/edge_test.go @@ -25,7 +25,7 @@ func TestInstallEdge(t *testing.T) { cmdparser.TestCmd[*edge.GetTags]() cmdparser.TestCmd[*Edge]( fmt.Sprintf( - `--accept-eula --user-database foo --errorlog-wait-line "Hello from Docker!" --registry %v --repo %v`, + `--accept-eula --database foo --errorlog-wait-line "Hello from Docker!" --registry %v --repo %v`, registry, repo)) diff --git a/cmd/modern/root/install/mssql-base.go b/cmd/modern/root/install/mssql-base.go index 3a2a11ed..ec92af81 100644 --- a/cmd/modern/root/install/mssql-base.go +++ b/cmd/modern/root/install/mssql-base.go @@ -5,9 +5,6 @@ package install import ( "fmt" - "net/url" - "path" - "path/filepath" "runtime" "strings" @@ -15,12 +12,13 @@ import ( "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" "github.com/microsoft/go-sqlcmd/internal/container" - "github.com/microsoft/go-sqlcmd/internal/http" "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/internal/output" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/microsoft/go-sqlcmd/internal/secret" "github.com/microsoft/go-sqlcmd/internal/sql" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/mechanism" "github.com/spf13/viper" ) @@ -55,7 +53,8 @@ type MssqlBase struct { port int - usingDatabaseUrl string + useDatabaseUrl string + useMechanism string unitTesting bool @@ -101,6 +100,14 @@ func (c *MssqlBase) AddFlags( String: &c.defaultDatabase, Name: "user-database", Shorthand: "u", + Hidden: true, + Usage: localizer.Sprintf("[DEPRECATED use --database] Create a user database and set it as the default for login"), + }) + + addFlag(cmdparser.FlagOptions{ + String: &c.defaultDatabase, + Name: "database", + Shorthand: "d", Usage: localizer.Sprintf("Create a user database and set it as the default for login"), }) @@ -213,10 +220,25 @@ func (c *MssqlBase) AddFlags( }) addFlag(cmdparser.FlagOptions{ - String: &c.usingDatabaseUrl, + String: &c.useDatabaseUrl, DefaultString: "", Name: "using", - Usage: localizer.Sprintf("Download (into container) and attach database (.bak) from URL"), + Hidden: true, + Usage: localizer.Sprintf("[DEPRECATED use --use] Download %q and use database", ingest.ValidFileExtensions()), + }) + + addFlag(cmdparser.FlagOptions{ + String: &c.useDatabaseUrl, + DefaultString: "", + Name: "use", + Usage: localizer.Sprintf("Download %q and use database", ingest.ValidFileExtensions()), + }) + + addFlag(cmdparser.FlagOptions{ + String: &c.useMechanism, + DefaultString: "", + Name: "use-mechanism", + Usage: localizer.Sprintf("Mechanism to use to bring database online (%s)", strings.Join(mechanism.Mechanisms(), ",")), }) } @@ -229,10 +251,10 @@ func (c *MssqlBase) AddFlags( // If the EULA has not been accepted, it prints an error message with suggestions for how to proceed, // and exits the program. func (c *MssqlBase) Run() { - output := c.Cmd.Output() - var imageName string + output := c.Cmd.Output() + if !c.acceptEula && viper.GetString("ACCEPT_EULA") == "" { output.FatalWithHints( []string{localizer.Sprintf("Either, add the %s flag to the command-line", localizer.AcceptEulaFlag), @@ -240,12 +262,9 @@ func (c *MssqlBase) Run() { localizer.Sprintf("EULA not accepted")) } - imageName = fmt.Sprintf( - "%s/%s:%s", - c.registry, - c.repo, - c.tag) + imageName = fmt.Sprintf("%s/%s:%s", c.registry, c.repo, c.tag) + // If no context name provided, set it to the default (e.g. mssql or edge) if c.contextName == "" { c.contextName = c.defaultContextName } @@ -253,7 +272,7 @@ func (c *MssqlBase) Run() { c.createContainer(imageName, c.contextName) } -// createContainer installs an image for a SQL Server container. The image +// createContainer creates a SQL Server container for an image. The image // is specified by imageName, and the container will be given the name contextName. // If the useCached flag is set, the function will skip downloading the image // from the internet. The function outputs progress messages to the command-line @@ -261,77 +280,71 @@ func (c *MssqlBase) Run() { // command-line and the program will exit. func (c *MssqlBase) createContainer(imageName string, contextName string) { output := c.Cmd.Output() + controller := container.NewController() saPassword := c.generatePassword() - env := []string{ - "ACCEPT_EULA=Y", - fmt.Sprintf("MSSQL_SA_PASSWORD=%s", saPassword), - fmt.Sprintf("MSSQL_COLLATION=%s", c.collation), - } - if c.port == 0 { c.port = config.FindFreePortForTds() } // Do an early exit if url doesn't exist - if c.usingDatabaseUrl != "" { - c.validateUsingUrlExists() + var useDatabase ingest.Ingest + if c.useDatabaseUrl != "" { + useDatabase = c.verifyUseSourceFileExists(controller, output) } if c.defaultDatabase != "" { if !c.validateDbName(c.defaultDatabase) { - output.Fatalf(localizer.Sprintf("--user-database %q contains non-ASCII chars and/or quotes", c.defaultDatabase)) + output.Fatalf(localizer.Sprintf("--database %q contains non-ASCII chars and/or quotes", c.defaultDatabase)) } } - controller := container.NewController() - if !c.useCached { c.downloadImage(imageName, output, controller) } + runOptions := container.RunOptions{ + Port: c.port, + Name: c.name, + Hostname: c.hostname, + Architecture: c.architecture, + Os: c.os} + + runOptions.Env = []string{ + "ACCEPT_EULA=Y", + fmt.Sprintf("MSSQL_SA_PASSWORD=%s", saPassword), + fmt.Sprintf("MSSQL_COLLATION=%s", c.collation)} + output.Infof(localizer.Sprintf("Starting %v", imageName)) containerId := controller.ContainerRun( imageName, - env, - c.port, - c.name, - c.hostname, - c.architecture, - c.os, - []string{}, - false, + runOptions, ) previousContextName := config.CurrentContextName() - userName := pal.UserName() - password := c.generatePassword() - // Save the config now, so user can uninstall/delete, even if mssql in the container // fails to start - config.AddContextWithContainer( - contextName, - imageName, - c.port, - containerId, - userName, - password, - c.passwordEncryption, - ) + contextOptions := config.ContextOptions{ + ImageName: imageName, + PortNumber: c.port, + ContainerId: containerId, + Username: pal.UserName(), + Password: c.generatePassword(), + PasswordEncryption: c.passwordEncryption} + config.AddContextWithContainer(contextName, contextOptions) output.Infof( - localizer.Sprintf("Created context %q in \"%s\", configuring user account...", + localizer.Sprintf("Created context %q in \"%s\", configuring user account", config.CurrentContextName(), config.GetConfigFileUsed())) - controller.ContainerWaitForLogEntry( - containerId, c.errorLogEntryToWaitFor) + controller.ContainerWaitForLogEntry(containerId, c.errorLogEntryToWaitFor) output.Infof( localizer.Sprintf("Disabled %q account (and rotated %q password). Creating user %q", "sa", "sa", - userName)) + contextOptions.Username)) endpoint, _ := config.CurrentContext() @@ -339,11 +352,11 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) { // // For Unit Testing we use the Docker Hello World container, which // starts much faster than the SQL Server container! + sqlOptions := sql.SqlOptions{} if c.errorLogEntryToWaitFor == "Hello from Docker!" { - c.sql = sql.New(sql.SqlOptions{UnitTesting: true}) - } else { - c.sql = sql.New(sql.SqlOptions{UnitTesting: false}) + sqlOptions.UnitTesting = true } + c.sql = sql.NewSql(sqlOptions) saUser := &sqlconfig.User{ AuthenticationType: "basic", @@ -353,17 +366,23 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) { Password: secret.Encode(saPassword, c.passwordEncryption)}, Name: "sa"} - c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Database: "master", Interactive: false}) + // Connect to master database on SQL Server in the container as `sa` + c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Database: "master"}) - c.createNonSaUser(userName, password) + // Create a new (non-sa) SQL Server user + c.createUser(contextOptions.Username, contextOptions.Password) - // Download and restore DB if asked - if c.usingDatabaseUrl != "" { - c.downloadAndRestoreDb( - controller, - containerId, - userName, - ) + // Download and restore/attach etc. DB if asked + if useDatabase != nil { + if useDatabase.IsRemoteUrl() { + output.Infof("Downloading %q to container", useDatabase.UrlFilename()) + } else { + output.Infof("Copying %q to container", useDatabase.UrlFilename()) + } + useDatabase.CopyToContainer(containerId) + + output.Infof("Bringing database %q online", useDatabase.DatabaseName()) + useDatabase.BringOnline(c.sql.Query, contextOptions.Username, contextOptions.Password) } hints := [][]string{} @@ -396,51 +415,47 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) { ) } -func (c *MssqlBase) validateUsingUrlExists() { - output := c.Cmd.Output() - databaseUrl := extractUrl(c.usingDatabaseUrl) - u, err := url.Parse(databaseUrl) - c.CheckErr(err) +func (c *MssqlBase) verifyUseSourceFileExists( + controller *container.Controller, + output *output.Output, +) (useDatabase ingest.Ingest) { + useDatabase = ingest.NewIngest(c.useDatabaseUrl, controller, ingest.IngestOptions{ + Mechanism: c.useMechanism, + }) - if u.Scheme != "http" && u.Scheme != "https" { + if !useDatabase.IsValidFileExtension() { output.FatalfWithHints( []string{ - localizer.Sprintf("--using URL must be http or https"), + fmt.Sprintf( + localizer.Sprintf("--use must be a path to a file with a %q extension"), + ingest.ValidFileExtensions(), + ), }, - localizer.Sprintf("%q is not a valid URL for --using flag", c.usingDatabaseUrl)) + localizer.Sprintf("%q is not a valid file extension for --use flag"), useDatabase.UserProvidedFileExt()) } - if u.Path == "" { + if useDatabase.IsRemoteUrl() && !useDatabase.IsValidScheme() { output.FatalfWithHints( []string{ - localizer.Sprintf("--using URL must have a path to .bak file"), + localizer.Sprintf("--use URL must one of %q"), + strings.Join(useDatabase.ValidSchemes(), ", "), }, - localizer.Sprintf("%q is not a valid URL for --using flag", c.usingDatabaseUrl)) + localizer.Sprintf("%q is not a valid URL for --use flag", c.useDatabaseUrl)) } - // At the moment we only support attaching .bak files, but we should - // support .bacpacs and .mdfs in the future - if _, file := filepath.Split(u.Path); filepath.Ext(file) != ".bak" { + if !useDatabase.SourceFileExists() { output.FatalfWithHints( - []string{ - localizer.Sprintf("--using file URL must be a .bak file"), - }, - localizer.Sprintf("Invalid --using file type")) + []string{localizer.Sprintf("File does not exist at URL %q", c.useDatabaseUrl)}, + "Unable to download file") } - - // Verify the url actually exists, and early exit if it doesn't - urlExists(databaseUrl, output) -} - -func (c *MssqlBase) query(commandText string) { - c.sql.Query(commandText) + return } -// createNonSaUser creates a user (non-sa) and assigns the sysadmin role +// createUser creates a user (non-sa) and assigns the sysadmin role // to the user. It also creates a default database with the provided name // and assigns the default database to the user. Finally, it disables // the sa account and rotates the sa password for security reasons. -func (c *MssqlBase) createNonSaUser( +func (c *MssqlBase) createUser( userName string, password string, ) { @@ -453,7 +468,7 @@ func (c *MssqlBase) createNonSaUser( // Create the default database, if it isn't a downloaded database output.Infof(localizer.Sprintf("Creating default database [%s]", defaultDatabase)) - c.query(fmt.Sprintf("CREATE DATABASE [%s]", defaultDatabase)) + c.sql.Query(fmt.Sprintf("CREATE DATABASE [%s]", defaultDatabase)) } const createLogin = `CREATE LOGIN [%s] @@ -465,139 +480,29 @@ CHECK_POLICY=OFF` @loginame = N'%s', @rolename = N'sysadmin'` - c.query(fmt.Sprintf(createLogin, userName, password, defaultDatabase)) - c.query(fmt.Sprintf(addSrvRoleMember, userName)) + if c.defaultDatabase != "" { + defaultDatabase = c.defaultDatabase + + // Create the default database, if it isn't a downloaded database + output.Infof("Creating default database [%s]", defaultDatabase) + c.sql.Query(fmt.Sprintf("CREATE DATABASE [%s]", defaultDatabase)) + } + + c.sql.Query(fmt.Sprintf(createLogin, userName, password, defaultDatabase)) + c.sql.Query(fmt.Sprintf(addSrvRoleMember, userName)) // Correct safety protocol is to rotate the sa password, because the first // sa password has been in the docker environment (as SA_PASSWORD) - c.query(fmt.Sprintf("ALTER LOGIN [sa] WITH PASSWORD = N'%s';", + c.sql.Query(fmt.Sprintf("ALTER LOGIN [sa] WITH PASSWORD = N'%s';", c.generatePassword())) - c.query("ALTER LOGIN [sa] DISABLE") + c.sql.Query("ALTER LOGIN [sa] DISABLE") if c.defaultDatabase != "" { - c.query(fmt.Sprintf("ALTER AUTHORIZATION ON DATABASE::[%s] TO %s", + c.sql.Query(fmt.Sprintf("ALTER AUTHORIZATION ON DATABASE::[%s] TO %s", defaultDatabase, userName)) } } -func getDbNameAsIdentifier(dbName string) string { - escapedDbNAme := strings.ReplaceAll(dbName, "'", "''") - return strings.ReplaceAll(escapedDbNAme, "]", "]]") -} - -func getDbNameAsNonIdentifier(dbName string) string { - return strings.ReplaceAll(dbName, "]", "]]") -} - -// parseDbName returns the databaseName from --using arg -// It sets database name to the specified database name -// or in absence of it, it is set to the filename without -// extension. -func parseDbName(usingDbUrl string) string { - u, _ := url.Parse(usingDbUrl) - dbToken := path.Base(u.Path) - if dbToken != "." && dbToken != "/" { - lastIdx := strings.LastIndex(dbToken, ".bak") - if lastIdx != -1 { - //Get file name without extension - fileName := dbToken[0:lastIdx] - lastIdx += 5 - if lastIdx >= len(dbToken) { - return fileName - } - //Return database name if it was specified - return dbToken[lastIdx:] - } - } - return "" -} - -func extractUrl(usingArg string) string { - urlEndIdx := strings.LastIndex(usingArg, ".bak") - if urlEndIdx != -1 { - return usingArg[0:(urlEndIdx + 4)] - } - return usingArg -} - -func (c *MssqlBase) downloadAndRestoreDb( - controller *container.Controller, - containerId string, - userName string, -) { - output := c.Cmd.Output() - databaseName := parseDbName(c.usingDatabaseUrl) - databaseUrl := extractUrl(c.usingDatabaseUrl) - - _, file := filepath.Split(databaseUrl) - - // Download file from URL into container - output.Infof(localizer.Sprintf("Downloading %s", file)) - - temporaryFolder := "/var/opt/mssql/backup" - - controller.DownloadFile( - containerId, - databaseUrl, - temporaryFolder, - ) - - // Restore database from file - output.Infof(localizer.Sprintf("Restoring database %s", databaseName)) - - dbNameAsIdentifier := getDbNameAsIdentifier(databaseName) - dbNameAsNonIdentifier := getDbNameAsNonIdentifier(databaseName) - - text := `SET NOCOUNT ON; - --- Build a SQL Statement to restore any .bak file to the Linux filesystem -DECLARE @sql NVARCHAR(max) - --- This table definition works since SQL Server 2017, therefore --- works for all SQL Server containers (which started in 2017) -DECLARE @fileListTable TABLE ( - [LogicalName] NVARCHAR(128), - [PhysicalName] NVARCHAR(260), - [Type] CHAR(1), - [FileGroupName] NVARCHAR(128), - [Size] NUMERIC(20,0), - [MaxSize] NUMERIC(20,0), - [FileID] BIGINT, - [CreateLSN] NUMERIC(25,0), - [DropLSN] NUMERIC(25,0), - [UniqueID] UNIQUEIDENTIFIER, - [ReadOnlyLSN] NUMERIC(25,0), - [ReadWriteLSN] NUMERIC(25,0), - [BackupSizeInBytes] BIGINT, - [SourceBlockSize] INT, - [FileGroupID] INT, - [LogGroupGUID] UNIQUEIDENTIFIER, - [DifferentialBaseLSN] NUMERIC(25,0), - [DifferentialBaseGUID] UNIQUEIDENTIFIER, - [IsReadOnly] BIT, - [IsPresent] BIT, - [TDEThumbprint] VARBINARY(32), - [SnapshotURL] NVARCHAR(360) -) - -INSERT INTO @fileListTable -EXEC('RESTORE FILELISTONLY FROM DISK = ''%s/%s''') -SET @sql = 'RESTORE DATABASE [%s] FROM DISK = ''%s/%s'' WITH ' -SELECT @sql = @sql + char(13) + ' MOVE ''' + LogicalName + ''' TO ''/var/opt/mssql/' + LogicalName + '.' + RIGHT(PhysicalName,CHARINDEX('\',PhysicalName)) + ''',' -FROM @fileListTable -WHERE IsPresent = 1 -SET @sql = SUBSTRING(@sql, 1, LEN(@sql)-1) -EXEC(@sql)` - - c.query(fmt.Sprintf(text, temporaryFolder, file, dbNameAsIdentifier, temporaryFolder, file)) - - alterDefaultDb := fmt.Sprintf( - "ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [%s]", - userName, - dbNameAsNonIdentifier) - c.query(alterDefaultDb) -} - func (c *MssqlBase) downloadImage( imageName string, output *output.Output, @@ -621,15 +526,6 @@ func (c *MssqlBase) downloadImage( } } -// Verify the file exists at the URL -func urlExists(url string, output *output.Output) { - if !http.UrlExists(url) { - output.FatalfWithHints( - []string{localizer.Sprintf("File does not exist at URL")}, - localizer.Sprintf("Unable to download file")) - } -} - func (c *MssqlBase) generatePassword() (password string) { password = secret.Generate( c.passwordLength, diff --git a/cmd/modern/root/install/mssql.go b/cmd/modern/root/install/mssql.go index 9465652c..a6e5ee0c 100644 --- a/cmd/modern/root/install/mssql.go +++ b/cmd/modern/root/install/mssql.go @@ -36,13 +36,13 @@ func (c *Mssql) DefineCommand(...cmdparser.CommandOptions) { }, { Description: localizer.Sprintf("Create SQL Server, download and attach AdventureWorks sample database"), - Steps: []string{"sqlcmd create mssql --using https://aka.ms/AdventureWorksLT.bak"}}, + Steps: []string{"sqlcmd create mssql --use https://aka.ms/AdventureWorksLT.bak"}}, { Description: localizer.Sprintf("Create SQL Server, download and attach AdventureWorks sample database with different database name"), - Steps: []string{"sqlcmd create mssql --using https://aka.ms/AdventureWorksLT.bak,adventureworks"}}, + Steps: []string{"sqlcmd create mssql --use https://aka.ms/AdventureWorksLT.bak,adventureworks"}}, { Description: localizer.Sprintf("Create SQL Server with an empty user database"), - Steps: []string{"sqlcmd create mssql --user-database db1"}}, + Steps: []string{"sqlcmd create mssql --database db1"}}, { Description: localizer.Sprintf("Install/Create SQL Server with full logging"), Steps: []string{"sqlcmd create mssql --verbosity 4"}}, diff --git a/cmd/modern/root/install/mssql_test.go b/cmd/modern/root/install/mssql_test.go index 28d80674..91ae7bc4 100644 --- a/cmd/modern/root/install/mssql_test.go +++ b/cmd/modern/root/install/mssql_test.go @@ -25,7 +25,7 @@ func TestInstallMssql(t *testing.T) { cmdparser.TestCmd[*mssql.GetTags]() cmdparser.TestCmd[*Mssql]( fmt.Sprintf( - `--accept-eula --user-database foo --errorlog-wait-line "Hello from Docker!" --registry %v --repo %v`, + `--accept-eula --database foo --errorlog-wait-line "Hello from Docker!" --registry %v --repo %v`, registry, repo)) @@ -54,34 +54,34 @@ func TestNegInstallMssql2(t *testing.T) { func TestNegInstallMssql3(t *testing.T) { cmdparser.TestSetup(t) assert.Panics(t, func() { - cmdparser.TestCmd[*Mssql]("--accept-eula --using https://does/not/exist.bak") + cmdparser.TestCmd[*Mssql]("--accept-eula --use https://does/not/exist.bak") }) } func TestNegInstallMssql4(t *testing.T) { cmdparser.TestSetup(t) assert.Panics(t, func() { - cmdparser.TestCmd[*Mssql]("--accept-eula --user-database bad'name") + cmdparser.TestCmd[*Mssql]("--accept-eula --database bad'name") }) } func TestNegInstallMssql5(t *testing.T) { cmdparser.TestSetup(t) assert.Panics(t, func() { - cmdparser.TestCmd[*Mssql]("--accept-eula --using https://not/bak/file") + cmdparser.TestCmd[*Mssql]("--accept-eula --use https://not/bak/file") }) } func TestNegInstallMssql6(t *testing.T) { cmdparser.TestSetup(t) assert.Panics(t, func() { - cmdparser.TestCmd[*Mssql]("--accept-eula --using file://not/http") + cmdparser.TestCmd[*Mssql]("--accept-eula --use file://not/http") }) } func TestNegInstallMssql7(t *testing.T) { cmdparser.TestSetup(t) assert.Panics(t, func() { - cmdparser.TestCmd[*Mssql]("--accept-eula --using https://aka.ms/AdventureWorksLT") + cmdparser.TestCmd[*Mssql]("--accept-eula --use https://aka.ms/AdventureWorksLT") }) } diff --git a/cmd/modern/root/query.go b/cmd/modern/root/query.go index b61958ea..8c5372c8 100644 --- a/cmd/modern/root/query.go +++ b/cmd/modern/root/query.go @@ -75,7 +75,7 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) { func (c *Query) run() { endpoint, user := config.CurrentContext() - s := sql.New(sql.SqlOptions{}) + s := sql.NewSql(sql.SqlOptions{}) if c.text == "" { s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: true}) } else { diff --git a/cmd/modern/root/use.go b/cmd/modern/root/use.go new file mode 100644 index 00000000..731b8190 --- /dev/null +++ b/cmd/modern/root/use.go @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package root + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/cmdparser" + "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/localizer" + "github.com/microsoft/go-sqlcmd/internal/secret" + "github.com/microsoft/go-sqlcmd/internal/sql" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/mechanism" + "strings" +) + +type Use struct { + cmdparser.Cmd + + url string + useMechanism string + + sql sql.Sql +} + +func (c *Use) DefineCommand(...cmdparser.CommandOptions) { + examples := []cmdparser.ExampleOptions{ + { + Description: "Download AdventureWorksLT into container for current context, set as default database", + Steps: []string{`sqlcmd use https://aka.ms/AdventureWorksLT.bak`}}, + } + + options := cmdparser.CommandOptions{ + Use: "use", + Short: fmt.Sprintf("Download database (into container) (%s)", ingest.ValidFileExtensions()), + Examples: examples, + Run: c.run, + FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{Flag: "url", Value: &c.url}, + } + + c.Cmd.DefineCommand(options) + + c.AddFlag(cmdparser.FlagOptions{ + String: &c.url, + Name: "url", + Usage: "Name of context to set as current context"}) + + c.AddFlag(cmdparser.FlagOptions{ + String: &c.useMechanism, + DefaultString: "", + Name: "use-mechanism", + Usage: localizer.Sprintf("Mechanism to use to bring database online (%s)", strings.Join(mechanism.Mechanisms(), ",")), + }) +} + +func (c *Use) run() { + output := useOutput{output: c.Output()} + + controller := container.NewController() + id := config.ContainerId() + + if !config.CurrentContextEndpointHasContainer() { + output.FatalNoContainerInCurrentContext() + } + + if !controller.ContainerRunning(id) { + output.FatalContainerNotRunning() + } + + endpoint, user := config.CurrentContext() + + c.sql = sql.NewSql(sql.SqlOptions{}) + c.sql.Connect(endpoint, user, sql.ConnectOptions{Database: "master"}) + + useDatabase := ingest.NewIngest(c.url, controller, ingest.IngestOptions{ + Mechanism: c.useMechanism, + }) + + if !useDatabase.SourceFileExists() { + output.FatalDatabaseSourceFileNotExist(c.url) + } + + // Copy source file (e.g. .bak/.bacpac etc.) for database to be made available to container + useDatabase.CopyToContainer(id) + + output.output.Infof("Bringing %q online using %q method", + useDatabase.DatabaseName(), + useDatabase.OnlineMethod(), + ) + + useDatabase.BringOnline( + c.sql.Query, + user.BasicAuth.Username, + secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncryption), + ) + + output.InfoDatabaseOnline(useDatabase.DatabaseName()) +} + +func (c *Use) query(commandText string) { + c.sql.Query(commandText) +} diff --git a/cmd/modern/root/use_output.go b/cmd/modern/root/use_output.go new file mode 100644 index 00000000..b29eb0bc --- /dev/null +++ b/cmd/modern/root/use_output.go @@ -0,0 +1,44 @@ +package root + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/output" + "runtime" +) + +type useOutput struct { + output.Output + output *output.Output +} + +func (u *useOutput) FatalNoContainerInCurrentContext() { + u.output.FatalfWithHintExamples([][]string{ + {"Create a context with a container", "sqlcmd create mssql"}, + }, "Current context does not have a container") +} + +func (u *useOutput) FatalContainerNotRunning() { + u.output.FatalfWithHintExamples([][]string{ + {"Start container for current context", "sqlcmd start"}, + }, "Container for current context is not running") +} + +func (u *useOutput) FatalDatabaseSourceFileNotExist(url string) { + u.output.FatalfWithHints( + []string{fmt.Sprintf("File does not exist at URL %q", url)}, + "Unable to download file to container") +} + +func (u *useOutput) InfoDatabaseOnline(databaseName string) { + hints := [][]string{} + + // TODO: sqlcmd open ads only support on Windows/Mac right now, add Linux support + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + hints = append(hints, []string{"Open in Azure Data Studio", "sqlcmd open ads"}) + } + + hints = append(hints, []string{"Run a query", "sqlcmd query \"SELECT DB_NAME()\""}) + hints = append(hints, []string{"See connection strings", "sqlcmd config connection-strings"}) + + u.output.InfofWithHintExamples(hints, "Database %q is now online", databaseName) +} diff --git a/internal/cmdparser/cmd.go b/internal/cmdparser/cmd.go index f6dcd50b..50f6faf9 100644 --- a/internal/cmdparser/cmd.go +++ b/internal/cmdparser/cmd.go @@ -85,6 +85,10 @@ func (c *Cmd) AddFlag(options FlagOptions) { options.Usage) } } + + if options.Hidden { + c.command.PersistentFlags().MarkHidden(options.Name) + } } // DefineCommand defines a command with the provided CommandOptions and adds diff --git a/internal/cmdparser/options.go b/internal/cmdparser/options.go index 499c8622..dc28f991 100644 --- a/internal/cmdparser/options.go +++ b/internal/cmdparser/options.go @@ -13,20 +13,19 @@ type AlternativeForFlagOptions struct { Value *string } -// FlagOptions type represents options for defining a flag for a command-line -// interface. The Name and Shorthand fields specify the long and short names -// for the flag, respectively. The Usage field is a string that describes how the -// flag should be used. The String, DefaultString, Int, DefaultInt, Bool, and -// DefaultBool fields are used to specify the type and default value of the flag, -// if it is a string, int, or bool type. The String and Int fields should be pointers -// to the variables that will store the flag's value, and the Bool field should be -// a pointer to a bool variable that will be set to true if the flag is present. The -// DefaultString, DefaultInt, and DefaultBool fields are the default values to -// use if the flag is not provided by the user. +// FlagOptions type represents options for defining a flag for a CLI. The Name +// and Shorthand fields specify the long and short names for the flag, respectively. +// The Usage field is a string that describes how the flag should be used. If you +// want the flag hidden from the --help, see the Hidden field to true. +// The String, DefaultString, Int, DefaultInt, Bool, and DefaultBool fields are +// used to specify the type and default value of the flag, use only one of these pairs +// (the one that match the type for the flag value). type FlagOptions struct { - Name string - Shorthand string - Usage string + Name string // e.g. --database + Shorthand string // e.g. -d + Usage string // e.g. "The database to connect to" + + Hidden bool // hide the flag from help (use for deprecated flags) String *string DefaultString string diff --git a/internal/config/config.go b/internal/config/config.go index 1b24695c..b579ba25 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -92,47 +92,42 @@ func IsEmpty() (isEmpty bool) { // requested. The updated configuration is saved to file. func AddContextWithContainer( contextName string, - imageName string, - portNumber int, - containerId string, - username string, - password string, - passwordEncryption string, + options ContextOptions, ) { - if containerId == "" { + if options.ContainerId == "" { panic("containerId must be provided") } - if imageName == "" { + if options.ImageName == "" { panic("imageName must be provided") } - if portNumber == 0 { + if options.PortNumber == 0 { panic("portNumber must be non-zero") } - if username == "" { + if options.Username == "" { panic("username must be provided") } - if password == "" { + if options.Password == "" { panic("password must be provided") } if contextName == "" { panic("contextName must be provided") } - contextName = FindUniqueContextName(contextName, username) + contextName = FindUniqueContextName(contextName, options.Username) endPointName := FindUniqueEndpointName(contextName) - userName := username + "@" + contextName + userName := options.Username + "@" + contextName config.CurrentContext = contextName config.Endpoints = append(config.Endpoints, Endpoint{ AssetDetails: &AssetDetails{ ContainerDetails: &ContainerDetails{ - Id: containerId, - Image: imageName}, + Id: options.ContainerId, + Image: options.ImageName}, }, EndpointDetails: EndpointDetails{ Address: "127.0.0.1", - Port: portNumber, + Port: options.PortNumber, }, Name: endPointName, }) @@ -148,9 +143,9 @@ func AddContextWithContainer( user := User{ AuthenticationType: "basic", BasicAuth: &BasicAuthDetails{ - Username: username, - PasswordEncryption: passwordEncryption, - Password: encryptCallback(password, passwordEncryption), + Username: options.Username, + PasswordEncryption: options.PasswordEncryption, + Password: encryptCallback(options.Password, options.PasswordEncryption), }, Name: userName, } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7e2540c2..aa2f8fed 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -131,7 +131,16 @@ func TestConfig(t *testing.T) { ContainerId() RemoveCurrentContext() RemoveCurrentContext() - AddContextWithContainer("context", "imageName", 1433, "containerId", "user", "password", "none") + + options := ContextOptions{ + ImageName: "imageName", + PortNumber: 1433, + ContainerId: "containerId", + Username: "user", + Password: "password", + PasswordEncryption: "none", + } + AddContextWithContainer("context", options) RemoveCurrentContext() DeleteEndpoint("endpoint") DeleteContext("context") @@ -324,7 +333,15 @@ func TestAddContextWithContainerPanic(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.Panics(t, func() { - AddContextWithContainer(tt.args.contextName, tt.args.imageName, tt.args.portNumber, tt.args.containerId, tt.args.username, tt.args.password, tt.args.passwordEncryption) + options := ContextOptions{ + ImageName: tt.args.imageName, + PortNumber: tt.args.portNumber, + ContainerId: tt.args.containerId, + Username: tt.args.username, + Password: tt.args.password, + PasswordEncryption: tt.args.passwordEncryption, + } + AddContextWithContainer(tt.args.contextName, options) }) }) } diff --git a/internal/config/endpoint-container.go b/internal/config/endpoint-container.go index dfc87f19..fe0f2ff4 100644 --- a/internal/config/endpoint-container.go +++ b/internal/config/endpoint-container.go @@ -43,7 +43,7 @@ func CurrentContextEndpointHasContainer() (exists bool) { currentContextName := config.CurrentContext if currentContextName == "" { - panic("currentContextName must not be empty") + return false } for _, c := range config.Contexts { diff --git a/internal/config/endpoint-container_test.go b/internal/config/endpoint-container_test.go index d7f2a02e..8a77a27b 100644 --- a/internal/config/endpoint-container_test.go +++ b/internal/config/endpoint-container_test.go @@ -11,14 +11,14 @@ import ( "testing" ) -// TestCurrentContextEndpointHasContainer verifies the function panics when +// TestCurrentContextEndpointHasContainer verifies the function returns false when // no current context func TestCurrentContextEndpointHasContainer(t *testing.T) { SetFileName(pal.FilenameInUserHomeDotDirectory( ".sqlcmd", "sqlconfig-TestCurrentContextEndpointHasContainer")) Clean() - assert.Panics(t, func() { CurrentContextEndpointHasContainer() }) + assert.False(t, CurrentContextEndpointHasContainer()) } func TestGetContainerId(t *testing.T) { diff --git a/internal/config/types.go b/internal/config/types.go new file mode 100644 index 00000000..41bd7e1d --- /dev/null +++ b/internal/config/types.go @@ -0,0 +1,10 @@ +package config + +type ContextOptions struct { + ImageName string + PortNumber int + ContainerId string + Username string + Password string + PasswordEncryption string +} diff --git a/internal/container/controller.go b/internal/container/controller.go index 0eea8fea..9367ec47 100644 --- a/internal/container/controller.go +++ b/internal/container/controller.go @@ -4,6 +4,7 @@ package container import ( + "archive/tar" "bufio" "bytes" "context" @@ -15,6 +16,7 @@ import ( "github.com/docker/go-connections/nat" specs "github.com/opencontainers/image-spec/specs-go/v1" "io" + "os" "path/filepath" "strconv" "strings" @@ -68,38 +70,31 @@ func (c Controller) EnsureImage(image string) (err error) { // the ID of the container. func (c Controller) ContainerRun( image string, - env []string, - port int, - name string, - hostname string, - architecture string, - os string, - command []string, - unitTestFailure bool, + options RunOptions, ) string { hostConfig := &container.HostConfig{ PortBindings: nat.PortMap{ nat.Port("1433/tcp"): []nat.PortBinding{ { HostIP: "0.0.0.0", - HostPort: strconv.Itoa(port), + HostPort: strconv.Itoa(options.Port), }, }, }, } platform := specs.Platform{ - Architecture: architecture, - OS: os, + Architecture: options.Architecture, + OS: options.Os, } resp, err := c.cli.ContainerCreate(context.Background(), &container.Config{ Tty: true, Image: image, - Cmd: command, - Env: env, - Hostname: hostname, - }, hostConfig, nil, &platform, name) + Cmd: options.Command, + Env: options.Env, + Hostname: options.Hostname, + }, hostConfig, nil, &platform, options.Name) checkErr(err) err = c.cli.ContainerStart( @@ -107,7 +102,7 @@ func (c Controller) ContainerRun( resp.ID, types.ContainerStartOptions{}, ) - if err != nil || unitTestFailure { + if err != nil || options.UnitTestFailure { // Remove the container, because we haven't persisted to config yet, so // uninstall won't work yet if resp.ID != "" { @@ -231,6 +226,41 @@ func (c Controller) ContainerFiles(id string, filespec string) (files []string) return strings.Split(string(stdout), "\n") } +func (c Controller) CopyFile(id string, src string, destFolder string) { + if id == "" { + panic("Must pass in non-empty id") + } + if src == "" { + panic("Must pass in non-empty src") + } + if destFolder == "" { + panic("Must pass in non-empty destFolder") + } + + _, f := filepath.Split(src) + h, err := os.ReadFile(src) + checkErr(err) + + // Create and add some files to the archive. + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + defer func() { + checkErr(tw.Close()) + }() + hdr := &tar.Header{ + Name: f, + Mode: 0600, + Size: int64(len(h)), + } + err = tw.WriteHeader(hdr) + checkErr(err) + _, err = tw.Write([]byte(h)) + checkErr(err) + + err = c.cli.CopyToContainer(context.Background(), id, destFolder, &buf, types.CopyToContainerOptions{}) + checkErr(err) +} + func (c Controller) DownloadFile(id string, src string, destFolder string) { if id == "" { panic("Must pass in non-empty id") @@ -242,10 +272,10 @@ func (c Controller) DownloadFile(id string, src string, destFolder string) { panic("Must pass in non-empty destFolder") } - cmd := []string{"mkdir", destFolder} - c.runCmdInContainer(id, cmd) + cmd := []string{"mkdir", "-p", destFolder} + c.RunCmdInContainer(id, cmd, ExecOptions{}) - _, file := filepath.Split(src) + _, file := filepath.Split(strings.Split(src, "?")[0]) // Wget the .bak file from the http src, and place it in /var/opt/sql/backup cmd = []string{ @@ -255,19 +285,25 @@ func (c Controller) DownloadFile(id string, src string, destFolder string) { src, } - c.runCmdInContainer(id, cmd) + c.RunCmdInContainer(id, cmd, ExecOptions{}) } -func (c Controller) runCmdInContainer(id string, cmd []string) ([]byte, []byte) { +func (c Controller) RunCmdInContainer( + id string, + cmd []string, + options ExecOptions, +) ([]byte, []byte) { trace("Running command in container: " + strings.Join(cmd, " ")) response, err := c.cli.ContainerExecCreate( context.Background(), id, types.ExecConfig{ + User: options.User, AttachStderr: true, AttachStdout: true, Cmd: cmd, + Env: options.Env, }, ) checkErr(err) diff --git a/internal/container/controller_test.go b/internal/container/controller_test.go index 347d7aec..bda30eca 100644 --- a/internal/container/controller_test.go +++ b/internal/container/controller_test.go @@ -33,17 +33,16 @@ func TestController_EnsureImage(t *testing.T) { c := NewController() err := c.EnsureImage(imageName) checkErr(err) - id := c.ContainerRun( - imageName, - []string{}, - port, - "", - "", - "amd64", - "linux", - []string{"ash", "-c", "echo 'Hello World'; sleep 3"}, - false, - ) + + runOptions := RunOptions{ + Env: []string{}, + Port: port, + Architecture: "amd64", + Os: "linux", + Command: []string{"ash", "-c", "echo 'Hello World'; sleep 3"}, + } + + id := c.ContainerRun(imageName, runOptions) c.ContainerRunning(id) c.ContainerWaitForLogEntry(id, "Hello World") c.ContainerExists(id) @@ -79,19 +78,13 @@ func TestController_ContainerRunFailure(t *testing.T) { c := NewController() - assert.Panics(t, func() { - c.ContainerRun( - imageName, - []string{}, - 0, - "", - "", - "amd64", - "linux", - []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, - false, - ) - }) + runOptions := RunOptions{ + Architecture: "amd64", + Os: "linux", + Command: []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, + } + + assert.Panics(t, func() { c.ContainerRun(imageName, runOptions) }) } func TestController_ContainerRunFailureCleanup(t *testing.T) { @@ -107,19 +100,13 @@ func TestController_ContainerRunFailureCleanup(t *testing.T) { c := NewController() - assert.Panics(t, func() { - c.ContainerRun( - imageName, - []string{}, - 0, - "", - "", - "amd64", - "linux", - []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, - true, - ) - }) + runOptions := RunOptions{ + Architecture: "amd64", + Os: "linux", + Command: []string{"ash", "-c", "echo 'Hello World'; sleep 1"}, + UnitTestFailure: true, + } + assert.Panics(t, func() { c.ContainerRun(imageName, runOptions) }) } func TestController_ContainerStopNeg2(t *testing.T) { diff --git a/internal/container/types.go b/internal/container/types.go new file mode 100644 index 00000000..20fb7b64 --- /dev/null +++ b/internal/container/types.go @@ -0,0 +1,17 @@ +package container + +type RunOptions struct { + Env []string + Port int + Name string + Hostname string + Architecture string + Os string + Command []string + UnitTestFailure bool +} + +type ExecOptions struct { + User string + Env []string +} diff --git a/internal/databaseurl/error.go b/internal/databaseurl/error.go new file mode 100644 index 00000000..039500da --- /dev/null +++ b/internal/databaseurl/error.go @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package databaseurl + +var errorCallback func(err error) + +func checkErr(err error) { + errorCallback(err) +} diff --git a/internal/databaseurl/factory.go b/internal/databaseurl/factory.go new file mode 100644 index 00000000..d607e269 --- /dev/null +++ b/internal/databaseurl/factory.go @@ -0,0 +1,70 @@ +package databaseurl + +import ( + url2 "net/url" + "path/filepath" + "strings" +) + +func NewDatabaseUrl(url string) *DatabaseUrl { + trace("NewDatabaseUrl(" + url + ")") + + databaseUrl := DatabaseUrl{} + + // To enable URL.Parse, switch to / from \\ + url = strings.Replace(url, "\\", "/", -1) + + // Cope with a URL that in the local directory, so it can be URL.Parsed() + if !strings.Contains(url, "/") { + url = "./" + url + } + + // Cope with a file:// URL that in the local directory, so it can be URL.Parsed() + if strings.HasPrefix(strings.ToLower(url), "file://") && + !strings.Contains(url[7:], "/") { + url = "file://./" + url[7:] + } + + parsedUrl, err := url2.Parse(url) + checkErr(err) + + databaseUrl.URL = parsedUrl + + trace("databaseUrl.URL.Path: " + databaseUrl.URL.Path) + + databaseUrl.Filename = filepath.Base(databaseUrl.URL.Path) + databaseUrl.FileExtension = strings.TrimLeft(filepath.Ext(databaseUrl.Filename), ".") + + split := strings.Split(databaseUrl.URL.Path, ",") + if len(split) > 1 { + databaseUrl.DatabaseName = split[1] + + // Remove the database name (specified after the comma) from the URL, and reparse it + url = strings.Replace(url, ","+split[1], "", 1) + databaseUrl.URL, err = databaseUrl.URL.Parse(url) + checkErr(err) + + split := strings.Split(databaseUrl.FileExtension, ",") + databaseUrl.FileExtension = split[0] + + split = strings.Split(databaseUrl.Filename, ",") + databaseUrl.Filename = split[0] + } else { + databaseUrl.DatabaseName = strings.TrimSuffix( + databaseUrl.Filename, + "."+databaseUrl.FileExtension, + ) + } + + trace("databaseUrl.Filename: " + databaseUrl.Filename) + trace("databaseUrl.FileExtension: " + databaseUrl.FileExtension) + trace("databaseUrl.DatabaseName: " + databaseUrl.DatabaseName) + + databaseUrl.IsLocal = databaseUrl.URL.Scheme == "file" || len(databaseUrl.URL.Scheme) < 3 + + escapedDbName := strings.ReplaceAll(databaseUrl.DatabaseName, "'", "''") + databaseUrl.DatabaseNameAsTsqlIdentifier = strings.ReplaceAll(escapedDbName, "]", "]]") + databaseUrl.DatabaseNameAsNonTsqlIdentifier = strings.ReplaceAll(databaseUrl.DatabaseName, "]", "]]") + + return &databaseUrl +} diff --git a/internal/databaseurl/factory_test.go b/internal/databaseurl/factory_test.go new file mode 100644 index 00000000..80a07f04 --- /dev/null +++ b/internal/databaseurl/factory_test.go @@ -0,0 +1,52 @@ +package databaseurl + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewDatabaseUrl(t *testing.T) { + tests := []struct { + url string + want string + }{ + {"https://example.com/testdb.bak,myDbName", "myDbName"}, + {"https://example.com/testdb.bak", "testdb"}, + {"https://example.com/test.foo", "test"}, + {"https://example.com/test.foo,test", "test"}, + {"https://example.com/test.7z,tsql_name", "tsql_name"}, + {"https://example.com/test.mdf,tsql_name?foo=bar", "tsql_name"}, + {"https://example.com/test.mdf,tsql_name#link?foo=bar", "tsql_name"}, + {"https://example.com/test.mdf?foo=bar", "test"}, + {"https://example.com/test.mdf#link?foo=bar", "test"}, + {"https://example.com/test,test", "test"}, + {"https://example.com,", ""}, + {"https://example.com", ""}, + {"test.7z,tsql_name", "tsql_name"}, + {"test.mdf,tsql_name", "tsql_name"}, + {"test.mdf", "test"}, + {"c:\\test.mdf", "test"}, + {"c:\\test.mdf,tsql_name", "tsql_name"}, + {"file://test.mdf,tsql_name", "tsql_name"}, + {"file://test.mdf", "test"}, + {"file://c:\\test.mdf", "test"}, + {"file://c:\\folder\\test.mdf", "test"}, + {"file://c:/test.mdf", "test"}, + {"file://c:/folder/test.mdf", "test"}, + {"file:\\test.mdf,tsql_name", "tsql_name"}, + {"file:\\test.mdf", "test"}, + {"file:\\c:\\test.mdf", "test"}, + {"file:\\c:\\folder\\test.mdf", "test"}, + {"file:\\c:/test.mdf", "test"}, + {"file:\\c:/folder/test.mdf", "test"}, + {"\\\\server\\share\\test.mdf", "test"}, + {"\\\\server\\share\\folder\\test.mdf", "test"}, + {"\\\\server\\share\\folder\\test.mdf,db_name", "db_name"}, + } + for _, tt := range tests { + t.Run("DatabaseURLTest-"+tt.url, func(t *testing.T) { + url := NewDatabaseUrl(tt.url) + assert.Equalf(t, tt.want, url.DatabaseName, "NewDatabaseUrl(%v)", url.DatabaseName) + }) + } +} diff --git a/internal/databaseurl/initialize.go b/internal/databaseurl/initialize.go new file mode 100644 index 00000000..26df3568 --- /dev/null +++ b/internal/databaseurl/initialize.go @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package databaseurl + +func init() { + Initialize( + func(err error) { + if err != nil { + panic(err) + } + }, + func(format string, a ...any) {}) +} + +func Initialize( + errorHandler func(err error), + traceHandler func(format string, a ...any)) { + + errorCallback = errorHandler + traceCallback = traceHandler +} diff --git a/internal/databaseurl/trace.go b/internal/databaseurl/trace.go new file mode 100644 index 00000000..05f2ce24 --- /dev/null +++ b/internal/databaseurl/trace.go @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package databaseurl + +var traceCallback func(format string, a ...any) + +func trace(format string, a ...any) { + traceCallback(format, a...) +} diff --git a/internal/databaseurl/type.go b/internal/databaseurl/type.go new file mode 100644 index 00000000..13bd9981 --- /dev/null +++ b/internal/databaseurl/type.go @@ -0,0 +1,20 @@ +package databaseurl + +import "net/url" + +type DatabaseUrl struct { + *url.URL + + Filename string + FileExtension string + IsLocal bool + + // DatabaseName returns the databaseName from --use arg + // It sets database name to the specified database name + // or in absence of it, it is set to the filename without + // extension. + DatabaseName string + + DatabaseNameAsTsqlIdentifier string + DatabaseNameAsNonTsqlIdentifier string +} diff --git a/cmd/modern/root/install/mssql-base_test.go b/internal/databaseurl/uri_test.go similarity index 78% rename from cmd/modern/root/install/mssql-base_test.go rename to internal/databaseurl/uri_test.go index a533eb2e..b076b7db 100644 --- a/cmd/modern/root/install/mssql-base_test.go +++ b/internal/databaseurl/uri_test.go @@ -1,11 +1,29 @@ -package install +package databaseurl import ( - "testing" - "github.com/stretchr/testify/assert" + "testing" ) +func TestExtractUrl(t *testing.T) { + type test struct { + inputURL string + expectedURL string + } + + tests := []test{ + {"https://example.com/testdb.bak,myDbName", "https://example.com/testdb.bak"}, + {"https://example.com/testdb.bak", "https://example.com/testdb.bak"}, + {"https://example.com,", "https://example.com,"}, + } + + for _, testcase := range tests { + u := NewDatabaseUrl(testcase.inputURL) + assert.Equal(t, testcase.expectedURL, u.String(), + "Extracted URL does not match expected URL") + } +} + func TestGetDbNameIfExists(t *testing.T) { type test struct { @@ -37,27 +55,11 @@ func TestGetDbNameIfExists(t *testing.T) { } for _, testcase := range tests { - dbname := parseDbName(testcase.input) - dbnameAsIdentifier := getDbNameAsIdentifier(dbname) - dbnameAsNonIdentifier := getDbNameAsNonIdentifier(dbname) - assert.Equal(t, testcase.expectedIdentifierOp, dbnameAsIdentifier, "Unexpected database name as identifier") - assert.Equal(t, testcase.expectedNonIdentifierOp, dbnameAsNonIdentifier, "Unexpected database name as non-identifier") - } -} + u := NewDatabaseUrl(testcase.input) -func TestExtractUrl(t *testing.T) { - type test struct { - inputURL string - expectedURL string - } - - tests := []test{ - {"https://example.com/testdb.bak,myDbName", "https://example.com/testdb.bak"}, - {"https://example.com/testdb.bak", "https://example.com/testdb.bak"}, - {"https://example.com,", "https://example.com,"}, - } - - for _, testcase := range tests { - assert.Equal(t, testcase.expectedURL, extractUrl(testcase.inputURL), "Extracted URL does not match expected URL") + assert.Equal(t, testcase.expectedIdentifierOp, u.DatabaseNameAsTsqlIdentifier, + "Unexpected database name as identifier") + assert.Equal(t, testcase.expectedNonIdentifierOp, u.DatabaseNameAsNonTsqlIdentifier, + "Unexpected database name as non-identifier") } } diff --git a/internal/intialize.go b/internal/intialize.go index e1bb9589..7be307e7 100644 --- a/internal/intialize.go +++ b/internal/intialize.go @@ -6,6 +6,7 @@ package internal import ( "github.com/microsoft/go-sqlcmd/internal/config" "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/databaseurl" "github.com/microsoft/go-sqlcmd/internal/http" "github.com/microsoft/go-sqlcmd/internal/io/file" "github.com/microsoft/go-sqlcmd/internal/net" @@ -45,5 +46,6 @@ func Initialize(options InitializeOptions) { secret.Initialize(options.ErrorHandler) net.Initialize(options.ErrorHandler, options.TraceHandler) http.Initialize(options.ErrorHandler, options.TraceHandler) + databaseurl.Initialize(options.ErrorHandler, options.TraceHandler) pal.Initialize(options.ErrorHandler, options.LineBreak) } diff --git a/internal/output/output.go b/internal/output/output.go index 6ce55f1b..7ab2c206 100644 --- a/internal/output/output.go +++ b/internal/output/output.go @@ -214,6 +214,7 @@ func (o Output) maskSecrets(text string) string { // Mask password from T/SQL e.g. ALTER LOGIN [sa] WITH PASSWORD = N'foo'; r := regexp.MustCompile(`(PASSWORD.*\s?=.*\s?N?')(.*)(')`) text = r.ReplaceAllString(text, "$1********$3") + return text } diff --git a/internal/sql/factory.go b/internal/sql/factory.go index 2713fee3..08ff2556 100644 --- a/internal/sql/factory.go +++ b/internal/sql/factory.go @@ -7,7 +7,7 @@ type SqlOptions struct { UnitTesting bool } -func New(options SqlOptions) Sql { +func NewSql(options SqlOptions) Sql { if options.UnitTesting { return &mock{} } else { diff --git a/internal/sql/mssql_test.go b/internal/sql/mssql_test.go index 169637ca..2a68e23a 100644 --- a/internal/sql/mssql_test.go +++ b/internal/sql/mssql_test.go @@ -77,7 +77,7 @@ func TestConnect(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mssql := New(SqlOptions{}) + mssql := NewSql(SqlOptions{}) // If test name ends in 'Panic' expect a Panic if strings.HasSuffix(tt.name, "Panic") { diff --git a/pkg/mssqlcontainer/ingest/factory.go b/pkg/mssqlcontainer/ingest/factory.go new file mode 100644 index 00000000..f8cf408b --- /dev/null +++ b/pkg/mssqlcontainer/ingest/factory.go @@ -0,0 +1,30 @@ +package ingest + +import ( + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/databaseurl" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/location" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/mechanism" + "strings" +) + +func NewIngest(databaseUrl string, controller *container.Controller, options IngestOptions) Ingest { + url := databaseurl.NewDatabaseUrl(databaseUrl) + + return &ingest{ + url: url, + controller: controller, + location: location.NewLocation(url.IsLocal, url.String(), controller), + mechanism: mechanism.NewMechanism(url.FileExtension, options.Mechanism, controller), + } +} + +func ValidFileExtensions() string { + var extensions []string + + for _, m := range mechanism.FileTypes() { + extensions = append(extensions, m) + } + + return strings.Join(extensions, ", ") +} diff --git a/pkg/mssqlcontainer/ingest/ingest.go b/pkg/mssqlcontainer/ingest/ingest.go new file mode 100644 index 00000000..0ec3554c --- /dev/null +++ b/pkg/mssqlcontainer/ingest/ingest.go @@ -0,0 +1,114 @@ +package ingest + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/databaseurl" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/location" + "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/mechanism" +) + +type ingest struct { + url *databaseurl.DatabaseUrl + location location.Location + controller *container.Controller + mechanism mechanism.Mechanism + options mechanism.BringOnlineOptions + containerId string + query func(text string) +} + +func (i *ingest) IsRemoteUrl() bool { + return !i.location.IsLocal() +} + +func (i *ingest) UrlFilename() string { + return i.url.Filename +} + +func (i *ingest) OnlineMethod() string { + return i.mechanism.Name() +} + +func (i *ingest) DatabaseName() string { + return i.url.DatabaseName +} + +func (i *ingest) IsValidScheme() bool { + for _, s := range i.location.ValidSchemes() { + if s == i.url.Scheme { + return true + } + } + return false +} + +func (i *ingest) CopyToContainer(containerId string) { + destFolder := "/var/opt/mssql/backup" + + if i.mechanism != nil { + destFolder = i.mechanism.CopyToLocation() + } + if i.location == nil { + panic("location is nil, did you call NewIngest()?") + } + + i.containerId = containerId + i.location.CopyToContainer(containerId, destFolder) + i.options.Filename = i.url.Filename + + if i.options.Filename == "" { + panic("filename is empty") + } +} + +func (i *ingest) BringOnline(query func(string), username string, password string) { + if i.options.Filename == "" { + panic("filename is empty, did you call CopyToContainer()?") + } + if query == nil { + panic("query is nil") + } + if i.mechanism == nil { + panic("mechanism is nil") + } + + i.query = query + i.options.Username = username + i.options.Password = password + i.mechanism.BringOnline(i.url.DatabaseNameAsTsqlIdentifier, i.containerId, i.query, i.options) + i.setDefaultDatabase(username) +} + +func (i *ingest) setDefaultDatabase(username string) { + if i.query == nil { + panic("query is nil, did you call BringOnline()?") + } + + alterDefaultDb := fmt.Sprintf( + "ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [%s]", + username, + i.url.DatabaseNameAsNonTsqlIdentifier) + i.query(alterDefaultDb) +} + +func (i *ingest) IsValidFileExtension() bool { + for _, m := range mechanism.FileTypes() { + if m == i.url.FileExtension { + return true + } + } + return false +} + +func (i *ingest) SourceFileExists() bool { + return i.location.Exists() +} + +func (i *ingest) UserProvidedFileExt() string { + return i.url.FileExtension +} + +func (i *ingest) ValidSchemes() []string { + return i.location.ValidSchemes() +} diff --git a/pkg/mssqlcontainer/ingest/interface.go b/pkg/mssqlcontainer/ingest/interface.go new file mode 100644 index 00000000..248bb475 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/interface.go @@ -0,0 +1,18 @@ +package ingest + +type Ingest interface { + IsRemoteUrl() bool + IsValidScheme() bool + IsValidFileExtension() bool + + SourceFileExists() bool + DatabaseName() string + UrlFilename() string + OnlineMethod() string + UserProvidedFileExt() string + + CopyToContainer(containerId string) + BringOnline(query func(string), username string, password string) + + ValidSchemes() []string +} diff --git a/pkg/mssqlcontainer/ingest/location/factory.go b/pkg/mssqlcontainer/ingest/location/factory.go new file mode 100644 index 00000000..4ccb4768 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/location/factory.go @@ -0,0 +1,19 @@ +package location + +import ( + "github.com/microsoft/go-sqlcmd/internal/container" +) + +func NewLocation(isLocal bool, uri string, controller *container.Controller) Location { + if isLocal { + return local{ + uri: uri, + controller: controller, + } + } else { + return remote{ + uri: uri, + controller: controller, + } + } +} diff --git a/pkg/mssqlcontainer/ingest/location/interface.go b/pkg/mssqlcontainer/ingest/location/interface.go new file mode 100644 index 00000000..b101b054 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/location/interface.go @@ -0,0 +1,8 @@ +package location + +type Location interface { + Exists() bool + IsLocal() bool + CopyToContainer(containerId string, destFolder string) + ValidSchemes() []string +} diff --git a/pkg/mssqlcontainer/ingest/location/local.go b/pkg/mssqlcontainer/ingest/location/local.go new file mode 100644 index 00000000..584b6456 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/location/local.go @@ -0,0 +1,52 @@ +package location + +import ( + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/io/file" + "path/filepath" +) + +type local struct { + uri string + controller *container.Controller +} + +func (l local) Exists() bool { + return file.Exists(l.uri) +} + +func (l local) IsLocal() bool { + return true +} + +func (l local) ValidSchemes() []string { + return []string{"file"} +} + +func (l local) CopyToContainer(containerId string, destFolder string) { + l.controller.RunCmdInContainer( + containerId, + []string{"mkdir", "-p", destFolder}, + container.ExecOptions{}, + ) + + l.controller.CopyFile( + containerId, + l.uri, + destFolder, + ) + + _, filename := filepath.Split(l.uri) + + l.controller.RunCmdInContainer( + containerId, + []string{"chown", "mssql:root", destFolder + "/" + filename}, + container.ExecOptions{User: "root"}, + ) + + l.controller.RunCmdInContainer( + containerId, + []string{"chmod", "-o-r-u+rw-g+r", destFolder + "/" + filename}, + container.ExecOptions{User: "root"}, + ) +} diff --git a/pkg/mssqlcontainer/ingest/location/remote.go b/pkg/mssqlcontainer/ingest/location/remote.go new file mode 100644 index 00000000..23c5dfbc --- /dev/null +++ b/pkg/mssqlcontainer/ingest/location/remote.go @@ -0,0 +1,32 @@ +package location + +import ( + "github.com/microsoft/go-sqlcmd/internal/container" + "github.com/microsoft/go-sqlcmd/internal/http" +) + +type remote struct { + uri string + controller *container.Controller +} + +func (l remote) IsLocal() bool { + return false +} + +func (l remote) ValidSchemes() []string { + return []string{"https", "http"} +} + +// Verify the file exists at the URL +func (l remote) Exists() bool { + return http.UrlExists(l.uri) +} + +func (l remote) CopyToContainer(containerId string, destFolder string) { + l.controller.DownloadFile( + containerId, + l.uri, + destFolder, + ) +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/attach.go b/pkg/mssqlcontainer/ingest/mechanism/attach.go new file mode 100644 index 00000000..94a34202 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/attach.go @@ -0,0 +1,68 @@ +package mechanism + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/container" +) + +type attach struct { + controller *container.Controller + containerId string +} + +func (m *attach) Initialize(controller *container.Controller) { + m.controller = controller +} + +func (m *attach) CopyToLocation() string { + return "/var/opt/mssql/data" +} + +func (m *attach) Name() string { + return "attach" +} + +func (m *attach) FileTypes() []string { + return []string{"mdf"} +} + +func (m *attach) BringOnline( + databaseName string, + containerId string, + query func(string), + options BringOnlineOptions, +) { + text := `SET NOCOUNT ON; ` + + m.containerId = containerId + m.setFilePermissions(m.CopyToLocation() + "/" + options.Filename) + if options.LdfFilename == "" { + text += `CREATE DATABASE [%s] ON (FILENAME = '%s/%s') FOR ATTACH;` + query(fmt.Sprintf( + text, + databaseName, + m.CopyToLocation(), + options.Filename, + )) + } else { + m.setFilePermissions(m.CopyToLocation() + "/" + options.LdfFilename) + text += `CREATE DATABASE [%s] ON (FILENAME = '%s/%s'), (FILENAME = '%s/%s') FOR ATTACH;` + query(fmt.Sprintf( + text, + databaseName, + m.CopyToLocation(), + options.Filename, + m.CopyToLocation(), + options.LdfFilename, + )) + } +} + +func (m *attach) setFilePermissions(filename string) { + m.RunCommand([]string{"chown", "mssql:root", filename}) + m.RunCommand([]string{"chmod", "-o-r-u+rw-g+r", filename}) +} + +func (m *attach) RunCommand(s []string) ([]byte, []byte) { + return m.controller.RunCmdInContainer(m.containerId, s, container.ExecOptions{}) +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/error.go b/pkg/mssqlcontainer/ingest/mechanism/error.go new file mode 100644 index 00000000..1ad16cc3 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/error.go @@ -0,0 +1,7 @@ +package mechanism + +var errorCallback func(err error) + +func checkErr(err error) { + errorCallback(err) +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/factory.go b/pkg/mssqlcontainer/ingest/mechanism/factory.go new file mode 100644 index 00000000..2e40ddf2 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/factory.go @@ -0,0 +1,46 @@ +package mechanism + +import ( + "github.com/microsoft/go-sqlcmd/internal/container" +) + +func NewMechanism(fileExtension string, name string, controller *container.Controller) Mechanism { + trace("NewMechanism: fileExtension = %q, name = %q"+fileExtension, name) + for _, m := range mechanisms { + if m.Name() == name { + m.Initialize(controller) + + trace("Returning: %q", m.Name()) + + return m + } + } + + return NewMechanismByFileExt(fileExtension, controller) +} + +func NewMechanismByFileExt(fileExtension string, controller *container.Controller) Mechanism { + for _, m := range mechanisms { + for _, ext := range m.FileTypes() { + if ext == fileExtension { + m.Initialize(controller) + + trace("Returning: %q", m.Name()) + + return m + } + } + } + + trace("No mechanism found for file extension %q", fileExtension) + + return nil +} + +func Mechanisms() []string { + m := []string{} + for _, mechanism := range mechanisms { + m = append(m, mechanism.Name()) + } + return m +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/initialize.go b/pkg/mssqlcontainer/ingest/mechanism/initialize.go new file mode 100644 index 00000000..e14358b2 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/initialize.go @@ -0,0 +1,19 @@ +package mechanism + +func init() { + Initialize( + func(err error) { + if err != nil { + panic(err) + } + }, + func(format string, a ...any) {}) +} + +func Initialize( + errorHandler func(err error), + traceHandler func(format string, a ...any)) { + + errorCallback = errorHandler + traceCallback = traceHandler +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/interface.go b/pkg/mssqlcontainer/ingest/mechanism/interface.go new file mode 100644 index 00000000..264fa3e1 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/interface.go @@ -0,0 +1,11 @@ +package mechanism + +import "github.com/microsoft/go-sqlcmd/internal/container" + +type Mechanism interface { + FileTypes() []string + Initialize(controller *container.Controller) + CopyToLocation() string + BringOnline(databaseName string, containerId string, query func(string), options BringOnlineOptions) + Name() string +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/mechanism.go b/pkg/mssqlcontainer/ingest/mechanism/mechanism.go new file mode 100644 index 00000000..32762ec4 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/mechanism.go @@ -0,0 +1,14 @@ +package mechanism + +var mechanisms = []Mechanism{ + &attach{}, + &restore{}, +} + +func FileTypes() []string { + fileTypes := []string{} + for _, m := range mechanisms { + fileTypes = append(fileTypes, m.FileTypes()...) + } + return fileTypes +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/options.go b/pkg/mssqlcontainer/ingest/mechanism/options.go new file mode 100644 index 00000000..dad738c0 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/options.go @@ -0,0 +1,9 @@ +package mechanism + +type BringOnlineOptions struct { + Username string + Password string + + Filename string + LdfFilename string +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/restore.go b/pkg/mssqlcontainer/ingest/mechanism/restore.go new file mode 100644 index 00000000..4e458f9c --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/restore.go @@ -0,0 +1,85 @@ +package mechanism + +import ( + "fmt" + "github.com/microsoft/go-sqlcmd/internal/container" +) + +type restore struct { +} + +func (m *restore) Initialize(controller *container.Controller) { +} + +func (m *restore) CopyToLocation() string { + return "/var/opt/mssql/backup" +} + +func (m *restore) Name() string { + return "restore" +} + +func (m *restore) FileTypes() []string { + return []string{"bak"} +} + +func (m *restore) BringOnline(databaseName string, _ string, query func(string), options BringOnlineOptions) { + if options.Filename == "" { + panic("Filename is required for restore") + } + if databaseName == "" { + panic("databaseName is required for restore") + } + + query(fmt.Sprintf( + m.restoreStatement(), + m.CopyToLocation(), + options.Filename, + databaseName, + m.CopyToLocation(), + options.Filename, + )) +} + +func (m *restore) restoreStatement() string { + return `SET NOCOUNT ON; + +-- Build a SQL Statement to restore any .bak file to the Linux filesystem +DECLARE @sql NVARCHAR(max) + +-- This table definition works since SQL Server 2017, therefore +-- works for all SQL Server containers (which started in 2017) +DECLARE @fileListTable TABLE ( + [LogicalName] NVARCHAR(128), + [PhysicalName] NVARCHAR(260), + [Type] CHAR(1), + [FileGroupName] NVARCHAR(128), + [Size] NUMERIC(20,0), + [MaxSize] NUMERIC(20,0), + [FileID] BIGINT, + [CreateLSN] NUMERIC(25,0), + [DropLSN] NUMERIC(25,0), + [UniqueID] UNIQUEIDENTIFIER, + [ReadOnlyLSN] NUMERIC(25,0), + [ReadWriteLSN] NUMERIC(25,0), + [BackupSizeInBytes] BIGINT, + [SourceBlockSize] INT, + [FileGroupID] INT, + [LogGroupGUID] UNIQUEIDENTIFIER, + [DifferentialBaseLSN] NUMERIC(25,0), + [DifferentialBaseGUID] UNIQUEIDENTIFIER, + [IsReadOnly] BIT, + [IsPresent] BIT, + [TDEThumbprint] VARBINARY(32), + [SnapshotURL] NVARCHAR(360) +) + +INSERT INTO @fileListTable +EXEC('RESTORE FILELISTONLY FROM DISK = ''%s/%s''') +SET @sql = 'RESTORE DATABASE [%s] FROM DISK = ''%s/%s'' WITH ' +SELECT @sql = @sql + char(13) + ' MOVE ''' + LogicalName + ''' TO ''/var/opt/mssql/data/' + LogicalName + '.' + RIGHT(PhysicalName,CHARINDEX('\',PhysicalName)) + ''',' +FROM @fileListTable +WHERE IsPresent = 1 +SET @sql = SUBSTRING(@sql, 1, LEN(@sql)-1) +EXEC(@sql)` +} diff --git a/pkg/mssqlcontainer/ingest/mechanism/trace.go b/pkg/mssqlcontainer/ingest/mechanism/trace.go new file mode 100644 index 00000000..5917db05 --- /dev/null +++ b/pkg/mssqlcontainer/ingest/mechanism/trace.go @@ -0,0 +1,7 @@ +package mechanism + +var traceCallback func(format string, a ...any) + +func trace(format string, a ...any) { + traceCallback(format, a...) +} diff --git a/pkg/mssqlcontainer/ingest/type.go b/pkg/mssqlcontainer/ingest/type.go new file mode 100644 index 00000000..dfc6da9b --- /dev/null +++ b/pkg/mssqlcontainer/ingest/type.go @@ -0,0 +1,5 @@ +package ingest + +type IngestOptions struct { + Mechanism string +} diff --git a/pkg/mssqlcontainer/initialize.go b/pkg/mssqlcontainer/initialize.go new file mode 100644 index 00000000..c0b53cab --- /dev/null +++ b/pkg/mssqlcontainer/initialize.go @@ -0,0 +1,19 @@ +package mssqlcontainer + +import "github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest/mechanism" + +type InitializeOptions struct { + ErrorHandler func(error) + TraceHandler func(format string, a ...any) +} + +func Initialize(options InitializeOptions) { + if options.ErrorHandler == nil { + panic("ErrorHandler is nil") + } + if options.TraceHandler == nil { + panic("TraceHandler is nil") + } + + mechanism.Initialize(options.ErrorHandler, options.TraceHandler) +}