Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions psp_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ struct psp_spi_tuple {
psp_spi spi;
};

struct psp_spi_addr_tuple {
struct psp_spi_tuple psp_spi;
struct in6_addr saddr;
};

#endif // PSP_KERNEL_H_
80 changes: 64 additions & 16 deletions psp_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "common.h"
#include <time.h>
#include "psp_lib.h"
#include "socket.h"
#include "stream.h"
Expand All @@ -31,6 +32,8 @@ static struct listener *listeners;
static int nlisteners;
static int next_listener;
static int listener_space = 1;
static pthread_mutex_t psp_listeners_lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t psp_listener_cv = PTHREAD_COND_INITIALIZER;

static void psp_resize_listeners_or_die(struct thread *ts) {
int new_max = listener_space << 2;
Expand Down Expand Up @@ -62,21 +65,46 @@ void psp_ctrl_client(int ctrl_conn, struct callbacks *cb) {

static int get_listen_fd(struct callbacks *cb, struct sockaddr_in6 *addr) {
int i;
int fd = -1;
int wait_rc = 0;
struct timespec ts;

/* key_server is single threaded exclusive user of nlisteners,
* so it is safe to access without locking.
/*
* locking here to protect shared listeners array, count, and cond variable
* while waiting for the requested port listener registration
*/
i = next_listener;
do {
if (addr->sin6_port == listeners[i].listenaddr.sin6_port) {
next_listener = (i + 1) % nlisteners;
return listeners[i].listenfd;
pthread_mutex_lock(&psp_listeners_lock);

clock_gettime(CLOCK_REALTIME, &ts);
ts.tv_sec += 5;

while (wait_rc == 0) {
if (listeners != NULL && nlisteners > 0) {
i = next_listener;
do {
if (addr->sin6_port == listeners[i].listenaddr.sin6_port) {
fd = listeners[i].listenfd;
next_listener = (i + 1) % nlisteners;
break;
}
i = (i + 1) % nlisteners;
} while (i != next_listener);

if (fd != -1) {
break;
}
}
wait_rc = pthread_cond_timedwait(&psp_listener_cv, &psp_listeners_lock, &ts);
}

i = (i + 1) % nlisteners;
} while (i != next_listener);
if (fd == -1) {
pthread_mutex_unlock(&psp_listeners_lock);
LOG_ERROR(cb, "get_listen_fd: Timed out waiting for listener socket for port %d to be registered.", ntohs(addr->sin6_port));
return -1;
}

return -1;
pthread_mutex_unlock(&psp_listeners_lock);
return fd;
}

static void *psp_key_server(void *arg)
Expand Down Expand Up @@ -142,13 +170,27 @@ static void *psp_key_server(void *arg)
LOG_FATAL(cb, "Port not found");
}

listen_tuple = req.client_tuple;
listen_size = sizeof(listen_tuple);

LOG_INFO(cb, "Setting listen key on fd %d", listenfd);

struct psp_spi_addr_tuple addr_tuple;
memset(&addr_tuple, 0, sizeof(addr_tuple));
addr_tuple.psp_spi = req.client_tuple;
addr_tuple.saddr = req.addr.sin6_addr;
socklen_t addr_size = sizeof(addr_tuple);

err = getsockopt(listenfd, IPPROTO_TCP, TCP_PSP_LISTENER,
&listen_tuple, &listen_size);
&addr_tuple, &addr_size);
if (err == 0) {
listen_tuple = addr_tuple.psp_spi;
} else if (errno == EINVAL) {
/* Fallback to V0 legacy struct for older kernels */
LOG_INFO(cb, "TCP_PSP_LISTENER size mismatch (AnyIP), retrying with legacy struct");
listen_tuple = req.client_tuple;
listen_size = sizeof(listen_tuple);
err = getsockopt(listenfd, IPPROTO_TCP, TCP_PSP_LISTENER,
&listen_tuple, &listen_size);
}

if (err < 0) {
LOG_FATAL(cb, "TCP_PSP_LISTENER failed: %s", strerror(errno));
}
Expand Down Expand Up @@ -254,7 +296,11 @@ void psp_post_listen(struct thread *t, int s, struct addrinfo *ai) {

LOG_INFO(t->cb, "registering PSP listener on port %d", ntohs(sin6->sin6_port));

pthread_mutex_lock(&psp_key_lock);
/*
* locking here to protect listeners count, array resize,
* and registration of the new listener socket
*/
pthread_mutex_lock(&psp_listeners_lock);

if ((nlisteners + 1) >= listener_space)
psp_resize_listeners_or_die(t);
Expand All @@ -263,5 +309,7 @@ void psp_post_listen(struct thread *t, int s, struct addrinfo *ai) {
listeners[nlisteners].listenaddr = *sin6;
nlisteners++;

pthread_mutex_unlock(&psp_key_lock);
pthread_cond_broadcast(&psp_listener_cv);

pthread_mutex_unlock(&psp_listeners_lock);
}