@@ -103,7 +103,11 @@ func main() {
103103 system , noLoad , debugMode bool
104104 )
105105
106- defaultSocketPath := func () string {
106+ envSocketPath := func () string {
107+ if val , ok := os .LookupEnv ("SSH_AUTH_SOCK" ); ok && socketPath == "" {
108+ return val
109+ }
110+
107111 dir := os .Getenv ("XDG_RUNTIME_DIR" )
108112 if dir == "" {
109113 dir = "/var/tmp"
@@ -113,7 +117,7 @@ func main() {
113117
114118 var sockets SocketSet
115119
116- flag .StringVar (& socketPath , "l" , defaultSocketPath , "path of the UNIX socket to listen on" )
120+ flag .StringVar (& socketPath , "l" , envSocketPath , "path of the UNIX socket to listen on" )
117121 flag .Var (& sockets , "A" , "fallback ssh-agent sockets" )
118122 flag .BoolVar (& swtpmFlag , "swtpm" , false , "use swtpm instead of actual tpm" )
119123 flag .BoolVar (& printSocketFlag , "print-socket" , false , "print path of UNIX socket to stdout" )
@@ -161,15 +165,6 @@ func main() {
161165 keyDir = utils .SSHDir ()
162166 }
163167
164- fi , err := os .Lstat (keyDir )
165- if err != nil {
166- slog .Error (err .Error ())
167- os .Exit (1 )
168- }
169- if fi .Mode ()& os .ModeSymlink == os .ModeSymlink {
170- slog .Info ("Not following symbolic link" , slog .String ("key_directory" , keyDir ))
171- }
172-
173168 if term .IsTerminal (int (os .Stdin .Fd ())) {
174169 slog .Info ("Warning: ssh-tpm-agent is meant to run as a background daemon." )
175170 slog .Info ("Running multiple instances is likely to lead to conflicts." )
@@ -187,44 +182,14 @@ func main() {
187182 agents = append (agents , sshagent .NewClient (conn ))
188183 }
189184
190- var listener * net.UnixListener
191-
192- if os .Getenv ("LISTEN_FDS" ) != "" {
193- if err != nil {
194- slog .Error (err .Error ())
195- os .Exit (1 )
196- }
197-
198- file := os .NewFile (uintptr (3 ), "ssh-tpm-agent.socket" )
199- fl , err := net .FileListener (file )
200- if err != nil {
201- slog .Error (err .Error ())
202- os .Exit (1 )
203- }
204- var ok bool
205- listener , ok = fl .(* net.UnixListener )
206- if ! ok {
207- slog .Error ("Socket-activation FD isn't a unix socket" )
208- os .Exit (1 )
209- }
210-
211- slog .Info ("Socket activated agent." )
212- } else {
213- os .Remove (socketPath )
214- if err := os .MkdirAll (filepath .Dir (socketPath ), 0o777 ); err != nil {
215- slog .Error ("Failed to create UNIX socket folder:" , err )
216- os .Exit (1 )
217- }
218- listener , err = net .ListenUnix ("unix" , & net.UnixAddr {Net : "unix" , Name : socketPath })
219- if err != nil {
220- slog .Error ("Failed to listen on UNIX socket:" , err )
221- os .Exit (1 )
222- }
223- slog .Info ("Listening on socket" , slog .String ("path" , socketPath ))
185+ listener , err := createListener (socketPath )
186+ if err != nil {
187+ slog .Error ("creating listener" , slog .String ("error" , err .Error ()))
188+ os .Exit (1 )
224189 }
225190
226- a := agent .NewAgent (listener ,
227- agents ,
191+ agent := agent .NewAgent (listener , agents ,
192+
228193 // TPM Callback
229194 func () (tpm transport.TPMCloser ) {
230195 // the agent will close the TPM after this is called
@@ -248,13 +213,48 @@ func main() {
248213 signal .Notify (c , syscall .SIGHUP )
249214 go func () {
250215 for range c {
251- a .Stop ()
216+ agent .Stop ()
252217 }
253218 }()
254219
255220 if ! noLoad {
256- a .LoadKeys (keyDir )
221+ if err := agent .LoadKeys (keyDir ); err != nil {
222+ slog .Error ("loading keys" , slog .String ("error" , err .Error ()))
223+ }
224+ }
225+
226+ agent .Wait ()
227+ }
228+
229+ func createListener (socketPath string ) (* net.UnixListener , error ) {
230+ if _ , ok := os .LookupEnv ("LISTEN_FDS" ); ok {
231+ f := os .NewFile (uintptr (3 ), "ssh-tpm-agent.socket" )
232+
233+ fListener , err := net .FileListener (f )
234+ if err != nil {
235+ return nil , err
236+ }
237+
238+ listener , ok := fListener .(* net.UnixListener )
239+ if ! ok {
240+ return nil , fmt .Errorf ("socket-activation file descriptor isn't an unix socket" )
241+ }
242+
243+ slog .Info ("Activated agent by socket" )
244+ return listener , nil
245+ }
246+
247+ _ = os .Remove (socketPath )
248+
249+ if err := os .MkdirAll (filepath .Dir (socketPath ), 0o770 ); err != nil {
250+ return nil , fmt .Errorf ("creating UNIX socket directory: %w" , err )
251+ }
252+
253+ listener , err := net .ListenUnix ("unix" , & net.UnixAddr {Net : "unix" , Name : socketPath })
254+ if err != nil {
255+ return nil , err
257256 }
258257
259- a .Wait ()
258+ slog .Info ("Listening on socket" , slog .String ("path" , socketPath ))
259+ return listener , nil
260260}
0 commit comments