Skip to content

Commit 25e8edb

Browse files
committed
Merge remote-tracking branch 'origin/pull/24'
* origin/pull/24: LoadKeys(): slog.Debug() + refactor ssh-tpm-agent: use SSH_AUTH_SOCK Refactor main() listener
2 parents 7f5985f + 763c2ae commit 25e8edb

File tree

2 files changed

+88
-76
lines changed

2 files changed

+88
-76
lines changed

agent/agent.go

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ func (a *Agent) List() ([]*agent.Key, error) {
127127
Comment: string(k.Comment),
128128
})
129129
}
130+
130131
return agentKeys, nil
131132
}
132133

@@ -268,35 +269,46 @@ func (a *Agent) Unlock(passphrase []byte) error {
268269
}
269270

270271
func LoadKeys(keyDir string) (map[string]*key.Key, error) {
271-
keys := map[string]*key.Key{}
272-
err := filepath.WalkDir(keyDir,
273-
func(path string, d fs.DirEntry, err error) error {
274-
if err != nil {
275-
return err
276-
}
277-
if d.IsDir() {
278-
return nil
279-
}
280-
if !strings.HasSuffix(path, "tpm") {
281-
return nil
282-
}
283-
f, err := os.ReadFile(path)
284-
if err != nil {
285-
return fmt.Errorf("failed reading %s", path)
286-
}
287-
k, err := key.DecodeKey(f)
288-
if err != nil {
289-
slog.Debug("not a TPM-sealed key", slog.String("key_path", path), slog.String("error", err.Error()))
290-
return nil
291-
}
292-
keys[k.Fingerprint()] = k
293-
return nil
294-
},
295-
)
272+
keyDir, err := filepath.EvalSymlinks(keyDir)
296273
if err != nil {
297274
return nil, err
298275
}
299-
return keys, nil
276+
277+
keys := make(map[string]*key.Key)
278+
279+
walkFunc := func(path string, d fs.DirEntry, err error) error {
280+
if err != nil {
281+
return err
282+
}
283+
284+
if d.IsDir() {
285+
return nil
286+
}
287+
288+
if !strings.HasSuffix(path, ".tpm") {
289+
slog.Debug("skipping key: does not have .tpm suffix", slog.String("name", path))
290+
return nil
291+
}
292+
293+
f, err := os.ReadFile(path)
294+
if err != nil {
295+
return fmt.Errorf("failed reading %s", path)
296+
}
297+
298+
k, err := key.DecodeKey(f)
299+
if err != nil {
300+
slog.Debug("not a TPM sealed key", slog.String("key_path", path), slog.String("error", err.Error()))
301+
return nil
302+
}
303+
304+
keys[k.Fingerprint()] = k
305+
306+
slog.Debug("added TPM key", slog.String("name", path))
307+
return nil
308+
}
309+
310+
err = filepath.WalkDir(keyDir, walkFunc)
311+
return keys, err
300312
}
301313

302314
func NewAgent(listener *net.UnixListener, agents []agent.ExtendedAgent, tpmFetch func() transport.TPMCloser, pin func(*key.Key) ([]byte, error)) *Agent {

cmd/ssh-tpm-agent/main.go

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,11 @@ func main() {
103103
system, noLoad, debugMode bool
104104
)
105105

106-
defaultSocketPath := func() string {
106+
envSocketPath := func() string {
107+
if val, ok := os.LookupEnv("SSH_AUTH_SOCK"); ok && socketPath == "" {
108+
return val
109+
}
110+
107111
dir := os.Getenv("XDG_RUNTIME_DIR")
108112
if dir == "" {
109113
dir = "/var/tmp"
@@ -113,7 +117,7 @@ func main() {
113117

114118
var sockets SocketSet
115119

116-
flag.StringVar(&socketPath, "l", defaultSocketPath, "path of the UNIX socket to listen on")
120+
flag.StringVar(&socketPath, "l", envSocketPath, "path of the UNIX socket to listen on")
117121
flag.Var(&sockets, "A", "fallback ssh-agent sockets")
118122
flag.BoolVar(&swtpmFlag, "swtpm", false, "use swtpm instead of actual tpm")
119123
flag.BoolVar(&printSocketFlag, "print-socket", false, "print path of UNIX socket to stdout")
@@ -161,15 +165,6 @@ func main() {
161165
keyDir = utils.SSHDir()
162166
}
163167

164-
fi, err := os.Lstat(keyDir)
165-
if err != nil {
166-
slog.Error(err.Error())
167-
os.Exit(1)
168-
}
169-
if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
170-
slog.Info("Not following symbolic link", slog.String("key_directory", keyDir))
171-
}
172-
173168
if term.IsTerminal(int(os.Stdin.Fd())) {
174169
slog.Info("Warning: ssh-tpm-agent is meant to run as a background daemon.")
175170
slog.Info("Running multiple instances is likely to lead to conflicts.")
@@ -187,44 +182,14 @@ func main() {
187182
agents = append(agents, sshagent.NewClient(conn))
188183
}
189184

190-
var listener *net.UnixListener
191-
192-
if os.Getenv("LISTEN_FDS") != "" {
193-
if err != nil {
194-
slog.Error(err.Error())
195-
os.Exit(1)
196-
}
197-
198-
file := os.NewFile(uintptr(3), "ssh-tpm-agent.socket")
199-
fl, err := net.FileListener(file)
200-
if err != nil {
201-
slog.Error(err.Error())
202-
os.Exit(1)
203-
}
204-
var ok bool
205-
listener, ok = fl.(*net.UnixListener)
206-
if !ok {
207-
slog.Error("Socket-activation FD isn't a unix socket")
208-
os.Exit(1)
209-
}
210-
211-
slog.Info("Socket activated agent.")
212-
} else {
213-
os.Remove(socketPath)
214-
if err := os.MkdirAll(filepath.Dir(socketPath), 0o777); err != nil {
215-
slog.Error("Failed to create UNIX socket folder:", err)
216-
os.Exit(1)
217-
}
218-
listener, err = net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: socketPath})
219-
if err != nil {
220-
slog.Error("Failed to listen on UNIX socket:", err)
221-
os.Exit(1)
222-
}
223-
slog.Info("Listening on socket", slog.String("path", socketPath))
185+
listener, err := createListener(socketPath)
186+
if err != nil {
187+
slog.Error("creating listener", slog.String("error", err.Error()))
188+
os.Exit(1)
224189
}
225190

226-
a := agent.NewAgent(listener,
227-
agents,
191+
agent := agent.NewAgent(listener, agents,
192+
228193
// TPM Callback
229194
func() (tpm transport.TPMCloser) {
230195
// the agent will close the TPM after this is called
@@ -248,13 +213,48 @@ func main() {
248213
signal.Notify(c, syscall.SIGHUP)
249214
go func() {
250215
for range c {
251-
a.Stop()
216+
agent.Stop()
252217
}
253218
}()
254219

255220
if !noLoad {
256-
a.LoadKeys(keyDir)
221+
if err := agent.LoadKeys(keyDir); err != nil {
222+
slog.Error("loading keys", slog.String("error", err.Error()))
223+
}
224+
}
225+
226+
agent.Wait()
227+
}
228+
229+
func createListener(socketPath string) (*net.UnixListener, error) {
230+
if _, ok := os.LookupEnv("LISTEN_FDS"); ok {
231+
f := os.NewFile(uintptr(3), "ssh-tpm-agent.socket")
232+
233+
fListener, err := net.FileListener(f)
234+
if err != nil {
235+
return nil, err
236+
}
237+
238+
listener, ok := fListener.(*net.UnixListener)
239+
if !ok {
240+
return nil, fmt.Errorf("socket-activation file descriptor isn't an unix socket")
241+
}
242+
243+
slog.Info("Activated agent by socket")
244+
return listener, nil
245+
}
246+
247+
_ = os.Remove(socketPath)
248+
249+
if err := os.MkdirAll(filepath.Dir(socketPath), 0o770); err != nil {
250+
return nil, fmt.Errorf("creating UNIX socket directory: %w", err)
251+
}
252+
253+
listener, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: socketPath})
254+
if err != nil {
255+
return nil, err
257256
}
258257

259-
a.Wait()
258+
slog.Info("Listening on socket", slog.String("path", socketPath))
259+
return listener, nil
260260
}

0 commit comments

Comments
 (0)