diff --git a/psp_kernel.h b/psp_kernel.h index e879a30..7b1b4c0 100644 --- a/psp_kernel.h +++ b/psp_kernel.h @@ -46,4 +46,9 @@ struct psp_spi_tuple { psp_spi spi; }; +struct psp_spi_addr_tuple { + struct psp_spi_tuple psp_spi; + struct in6_addr saddr; +}; + #endif // PSP_KERNEL_H_ diff --git a/psp_lib.c b/psp_lib.c index d1c064f..6c08b08 100644 --- a/psp_lib.c +++ b/psp_lib.c @@ -15,6 +15,7 @@ */ #include "common.h" +#include #include "psp_lib.h" #include "socket.h" #include "stream.h" @@ -31,6 +32,8 @@ static struct listener *listeners; static int nlisteners; static int next_listener; static int listener_space = 1; +static pthread_mutex_t psp_listeners_lock = PTHREAD_MUTEX_INITIALIZER; +static pthread_cond_t psp_listener_cv = PTHREAD_COND_INITIALIZER; static void psp_resize_listeners_or_die(struct thread *ts) { int new_max = listener_space << 2; @@ -62,21 +65,46 @@ void psp_ctrl_client(int ctrl_conn, struct callbacks *cb) { static int get_listen_fd(struct callbacks *cb, struct sockaddr_in6 *addr) { int i; + int fd = -1; + int wait_rc = 0; + struct timespec ts; - /* key_server is single threaded exclusive user of nlisteners, - * so it is safe to access without locking. + /* + * locking here to protect shared listeners array, count, and cond variable + * while waiting for the requested port listener registration */ - i = next_listener; - do { - if (addr->sin6_port == listeners[i].listenaddr.sin6_port) { - next_listener = (i + 1) % nlisteners; - return listeners[i].listenfd; + pthread_mutex_lock(&psp_listeners_lock); + + clock_gettime(CLOCK_REALTIME, &ts); + ts.tv_sec += 5; + + while (wait_rc == 0) { + if (listeners != NULL && nlisteners > 0) { + i = next_listener; + do { + if (addr->sin6_port == listeners[i].listenaddr.sin6_port) { + fd = listeners[i].listenfd; + next_listener = (i + 1) % nlisteners; + break; + } + i = (i + 1) % nlisteners; + } while (i != next_listener); + + if (fd != -1) { + break; + } } + wait_rc = pthread_cond_timedwait(&psp_listener_cv, &psp_listeners_lock, &ts); + } - i = (i + 1) % nlisteners; - } while (i != next_listener); + if (fd == -1) { + pthread_mutex_unlock(&psp_listeners_lock); + LOG_ERROR(cb, "get_listen_fd: Timed out waiting for listener socket for port %d to be registered.", ntohs(addr->sin6_port)); + return -1; + } - return -1; + pthread_mutex_unlock(&psp_listeners_lock); + return fd; } static void *psp_key_server(void *arg) @@ -142,13 +170,27 @@ static void *psp_key_server(void *arg) LOG_FATAL(cb, "Port not found"); } - listen_tuple = req.client_tuple; - listen_size = sizeof(listen_tuple); - LOG_INFO(cb, "Setting listen key on fd %d", listenfd); + struct psp_spi_addr_tuple addr_tuple; + memset(&addr_tuple, 0, sizeof(addr_tuple)); + addr_tuple.psp_spi = req.client_tuple; + addr_tuple.saddr = req.addr.sin6_addr; + socklen_t addr_size = sizeof(addr_tuple); + err = getsockopt(listenfd, IPPROTO_TCP, TCP_PSP_LISTENER, - &listen_tuple, &listen_size); + &addr_tuple, &addr_size); + if (err == 0) { + listen_tuple = addr_tuple.psp_spi; + } else if (errno == EINVAL) { + /* Fallback to V0 legacy struct for older kernels */ + LOG_INFO(cb, "TCP_PSP_LISTENER size mismatch (AnyIP), retrying with legacy struct"); + listen_tuple = req.client_tuple; + listen_size = sizeof(listen_tuple); + err = getsockopt(listenfd, IPPROTO_TCP, TCP_PSP_LISTENER, + &listen_tuple, &listen_size); + } + if (err < 0) { LOG_FATAL(cb, "TCP_PSP_LISTENER failed: %s", strerror(errno)); } @@ -254,7 +296,11 @@ void psp_post_listen(struct thread *t, int s, struct addrinfo *ai) { LOG_INFO(t->cb, "registering PSP listener on port %d", ntohs(sin6->sin6_port)); - pthread_mutex_lock(&psp_key_lock); + /* + * locking here to protect listeners count, array resize, + * and registration of the new listener socket + */ + pthread_mutex_lock(&psp_listeners_lock); if ((nlisteners + 1) >= listener_space) psp_resize_listeners_or_die(t); @@ -263,5 +309,7 @@ void psp_post_listen(struct thread *t, int s, struct addrinfo *ai) { listeners[nlisteners].listenaddr = *sin6; nlisteners++; - pthread_mutex_unlock(&psp_key_lock); + pthread_cond_broadcast(&psp_listener_cv); + + pthread_mutex_unlock(&psp_listeners_lock); }