Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions configparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,31 @@ type DatabaseConfig struct {
ConnString string
PasswordVar string
JustUsePgPass bool
Sshconfig SSHConnConfig
}

func DecodeDatabases(crontab io.Reader, usepgpass bool) (map[string]string, error) {
func DecodeDatabases(crontab io.Reader, usepgpass bool) (map[string]DatabaseConfig, error) {
var configs map[string]DatabaseConfig
decoder := toml.NewDecoder(crontab)
err := decoder.Decode(&configs)
if err != nil {
return nil, err
}
databases := map[string]string{}
databases := map[string]DatabaseConfig{}
for key, config := range configs {
if config.ConnString == "" {
return nil, fmt.Errorf("Missing connstring in database %s", key)
}
if usepgpass || config.JustUsePgPass {
databases[key] = strings.Replace(config.ConnString, ":$password", "", 1)
} else if config.PasswordVar == "" {
databases[key] = config.ConnString
} else {
config.ConnString = strings.Replace(config.ConnString, ":$password", "", 1)
} else if config.PasswordVar != "" {
password := os.Getenv(config.PasswordVar)
if password == "" {
return nil, fmt.Errorf("Injected passwordvar %s is empty!", config.PasswordVar)
}
databases[key] = strings.Replace(config.ConnString, "$password", password, 1)
config.ConnString = strings.Replace(config.ConnString, "$password", password, 1)
}
databases[key] = config
}
return databases, nil
}
Expand All @@ -65,18 +65,18 @@ func DecodeJobs(crontab io.Reader) (jobconfigs map[string]JobConfig, err error)
return jobconfigs, nil
}

func CreateJobs(configs map[string]JobConfig, databases map[string]string, monitor Monitor) ([]Job, error) {
func CreateJobs(configs map[string]JobConfig, databases map[string]DatabaseConfig, monitor Monitor) ([]Job, error) {
jobs := []Job{}
for name, config := range configs {
schedule, err := cron.ParseStandard(config.CronSchedule)
if err != nil {
return nil, fmt.Errorf("Cron schedule error: %w", err)
}
connstr, ok := databases[config.Database]
dbconfig, ok := databases[config.Database]
if !ok {
return nil, fmt.Errorf("Missing Db: The database %s specified by job %s does not seem to exist!", config.Database, name)
}
job, err := CreateJob(name, config.Database, schedule, connstr, config.Query, config.JobMiscOptions, monitor)
job, err := CreateJob(name, config.Database, schedule, dbconfig.ConnString, config.Query, dbconfig.Sshconfig, config.JobMiscOptions, monitor)
if err != nil {
return nil, err
}
Expand Down
13 changes: 12 additions & 1 deletion pgxjob.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package main

import (
"context"
"fmt"
"log"
"net"
"slices"
"strings"
"time"
Expand All @@ -26,7 +28,7 @@ type Job struct {
valid bool
}

func CreateJob(jobname, dbname string, s Schedule, target, query string, misc JobMiscOptions, monitor Monitor) (j Job, err error) {
func CreateJob(jobname, dbname string, s Schedule, target, query string, ssh SSHConnConfig, misc JobMiscOptions, monitor Monitor) (j Job, err error) {
if jobname == "" || dbname == "" || s == nil {
return j, fmt.Errorf("Received nil input(s) when creating %s", jobname)
}
Expand All @@ -47,6 +49,15 @@ func CreateJob(jobname, dbname string, s Schedule, target, query string, misc Jo
if config.ConnectTimeout == time.Duration(0) { // Default to 50 seconds if no finite timeout is provided
config.ConnectTimeout = 50 * time.Second // via the standard pgx & psql PGCONNECT_TIMEOUT env var
}
if ssh.Host != "" {
client, err := NewSSHClient(&ssh)
if err != nil {
return j, err
}
config.DialFunc = func(ctx context.Context, network string, addr string) (net.Conn, error) {
return client.Dial(network, addr)
}
}

return Job{
JobName: jobname,
Expand Down
58 changes: 58 additions & 0 deletions ssh_tunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package main

import (
"net"
"os"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)

type SSHConnConfig struct {
Host string
Port string
User string
Knownhosts string
Keyfile string
}

func NewSSHClient(config *SSHConnConfig) (*ssh.Client, error) {
sshConfig := &ssh.ClientConfig{
User: config.User,
}

if auth := SSHAgent(); auth != nil {
sshConfig.Auth = append(sshConfig.Auth, auth)
}

if hostKeyCallback, err := knownhosts.New(config.Knownhosts); err == nil {
sshConfig.HostKeyCallback = hostKeyCallback
}
if config.Keyfile != "" {
if auth := PrivateKey(config.Keyfile); auth != nil {
sshConfig.Auth = append(sshConfig.Auth, auth)
}
}

return ssh.Dial("tcp", net.JoinHostPort(config.Host, config.Port), sshConfig)
}

func SSHAgent() ssh.AuthMethod {
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
}
return nil
}

func PrivateKey(path string) ssh.AuthMethod {
key, err := os.ReadFile(path)
if err != nil {
return nil
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return nil
}
return ssh.PublicKeys(signer)
}
Loading