Skip to content

Commit

Permalink
The trusted client feature initial checkin
Browse files Browse the repository at this point in the history
Signed-off-by: Shivshankar-Reddy <[email protected]>
  • Loading branch information
Shivshankar-Reddy committed Jun 17, 2024
1 parent 5a51bf5 commit 914a7c7
Show file tree
Hide file tree
Showing 13 changed files with 250 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ endif
ENGINE_NAME=valkey
SERVER_NAME=$(ENGINE_NAME)-server$(PROG_SUFFIX)
ENGINE_SENTINEL_NAME=$(ENGINE_NAME)-sentinel$(PROG_SUFFIX)
ENGINE_SERVER_OBJ=threads_mngr.o adlist.o quicklist.o ae.o anet.o dict.o kvstore.o server.o sds.o zmalloc.o lzf_c.o lzf_d.o pqsort.o zipmap.o sha1.o ziplist.o release.o networking.o util.o object.o db.o replication.o rdb.o t_string.o t_list.o t_set.o t_zset.o t_hash.o config.o aof.o pubsub.o multi.o debug.o sort.o intset.o syncio.o cluster.o cluster_legacy.o crc16.o endianconv.o slowlog.o eval.o bio.o rio.o rand.o memtest.o syscheck.o crcspeed.o crccombine.o crc64.o bitops.o sentinel.o notify.o setproctitle.o blocked.o hyperloglog.o latency.o sparkline.o valkey-check-rdb.o valkey-check-aof.o geo.o lazyfree.o module.o evict.o expire.o geohash.o geohash_helper.o childinfo.o defrag.o siphash.o rax.o t_stream.o listpack.o localtime.o lolwut.o lolwut5.o lolwut6.o acl.o tracking.o socket.o tls.o sha256.o timeout.o setcpuaffinity.o monotonic.o mt19937-64.o resp_parser.o call_reply.o script_lua.o script.o functions.o function_lua.o commands.o strl.o connection.o unix.o logreqres.o
ENGINE_SERVER_OBJ=threads_mngr.o adlist.o quicklist.o ae.o anet.o dict.o kvstore.o server.o sds.o zmalloc.o lzf_c.o lzf_d.o pqsort.o zipmap.o sha1.o ziplist.o release.o networking.o util.o object.o db.o replication.o rdb.o t_string.o t_list.o t_set.o t_zset.o t_hash.o config.o aof.o pubsub.o multi.o debug.o sort.o intset.o syncio.o cluster.o cluster_legacy.o crc16.o endianconv.o slowlog.o eval.o bio.o rio.o rand.o memtest.o syscheck.o crcspeed.o crccombine.o crc64.o bitops.o sentinel.o notify.o setproctitle.o blocked.o hyperloglog.o latency.o sparkline.o valkey-check-rdb.o valkey-check-aof.o geo.o lazyfree.o module.o evict.o expire.o geohash.o geohash_helper.o childinfo.o defrag.o siphash.o rax.o t_stream.o listpack.o localtime.o lolwut.o lolwut5.o lolwut6.o acl.o tracking.o socket.o tls.o sha256.o timeout.o setcpuaffinity.o monotonic.o mt19937-64.o resp_parser.o call_reply.o script_lua.o script.o functions.o function_lua.o commands.o strl.o connection.o unix.o logreqres.o trusted_network.o
ENGINE_CLI_NAME=$(ENGINE_NAME)-cli$(PROG_SUFFIX)
ENGINE_CLI_OBJ=anet.o adlist.o dict.o valkey-cli.o zmalloc.o release.o ae.o serverassert.o crcspeed.o crccombine.o crc64.o siphash.o crc16.o monotonic.o cli_common.o mt19937-64.o strl.o cli_commands.o
ENGINE_BENCHMARK_NAME=$(ENGINE_NAME)-benchmark$(PROG_SUFFIX)
Expand Down
6 changes: 4 additions & 2 deletions src/anet.c
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,9 @@ static int anetGenericAccept(char *err, int s, struct sockaddr *sa, socklen_t *l
}

/* Accept a connection and also make sure the socket is non-blocking, and CLOEXEC.
* returns the new socket FD, or -1 on error. */
int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port) {
* returns the new socket FD, or -1 on error.
* If client_addr is not null, it will receive a copy of the client's sockaddr_storage structure. */
int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port, struct sockaddr_storage *client_addr) {
int fd;
struct sockaddr_storage sa;
socklen_t salen = sizeof(sa);
Expand All @@ -645,6 +646,7 @@ int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port)
if (ip) inet_ntop(AF_INET6, (void *)&(s->sin6_addr), ip, ip_len);
if (port) *port = ntohs(s->sin6_port);
}
if (client_addr) { memcpy(client_addr, &sa, sizeof(sa)); }
return fd;
}

Expand Down
4 changes: 3 additions & 1 deletion src/anet.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
#undef ip_len
#endif

struct sockaddr_storage;

int anetTcpNonBlockConnect(char *err, const char *addr, int port);
int anetTcpNonBlockBestEffortBindConnect(char *err, const char *addr, int port, const char *source_addr);
int anetResolve(char *err, char *host, char *ipbuf, size_t ipbuf_len, int flags);
int anetTcpServer(char *err, int port, char *bindaddr, int backlog);
int anetTcp6Server(char *err, int port, char *bindaddr, int backlog);
int anetUnixServer(char *err, char *path, mode_t perm, int backlog);
int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port);
int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port, struct sockaddr_storage *client_addr);
int anetUnixAccept(char *err, int serversock);
int anetNonBlock(char *err, int fd);
int anetBlock(char *err, int fd);
Expand Down
2 changes: 1 addition & 1 deletion src/cluster_legacy.c
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
if (server.primary_host == NULL && server.loading) return;

while (max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport, NULL);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK) serverLog(LL_VERBOSE, "Error accepting cluster node: %s", server.neterr);
return;
Expand Down
80 changes: 80 additions & 0 deletions src/config.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <string.h>
#include <locale.h>
#include <ctype.h>
#include <arpa/inet.h>

/*-----------------------------------------------------------------------------
* Config file name-value maps.
Expand Down Expand Up @@ -2853,6 +2854,84 @@ static sds getConfigNotifyKeyspaceEventsOption(standardConfig *config) {
return keyspaceEventsFlagsToString(server.notify_keyspace_events);
}

static int setConfigTrustedAddresses(standardConfig *config, sds* argv, int argc, const char **err) {
UNUSED(config);
int j;
int skip = 0;

if ((unsigned int)argc > server.maxclients || server.trustedIPCount > server.maxclients) {
*err = "Too many addresses specified.";
return 0;
}

/* A single empty argument is treated as a zero address count */
if (argc == 1 && sdslen(argv[0]) == 0) argc = 0;

for (j = 0; j < argc; j++) {
const char* ip = zstrdup(argv[j]);
in_addr_t addr = inet_addr(ip);
if (addr == 0 || addr == INADDR_NONE) {
if(!ip) sds_free((void *)ip);
*err = "Invalid adress is specified.";
return 0;
}
if (server.trustedIPCount && checkTrustedIP(addr)) {
serverLog(LL_NOTICE, "Do not add, IP is exist in the list");
skip++;
continue;
}
server.trustedIPList = zrealloc(server.trustedIPList,
sizeof(in_addr_t) * (server.trustedIPCount + j - skip + 1));
server.trustedIPList[j + server.trustedIPCount - skip] = addr;
sds_free((void *)ip);
}
server.trustedIPCount = server.trustedIPCount + j - skip;
valkeySortIP(server.trustedIPList, server.trustedIPCount);

return 1;
}

static sds getConfigTrustedAddresses(standardConfig *config) {
UNUSED(config);

unsigned int i;
sds reply = sdsempty();

for (i = 0; i < server.trustedIPCount; i++) {
struct in_addr addr = { 0 };
addr.s_addr = server.trustedIPList[i];
reply = sdscat(reply, inet_ntoa(addr));
if (i != (server.trustedIPCount - 1)) {
reply = sdscat(reply," ");
}
}
return reply;
}

/* Rewrite the trusted-addresses option. Rewrites trusted-addresses parameters,
* or simply return to avoid the defaults from being used.*/
void rewriteConfigTrustedAdresses(standardConfig *config, const char *name, struct rewriteConfigState *state) {
UNUSED(config);
sds line = sdsempty();

if (!server.trustedIPCount) {
sdsfree(line);
return;
} else {
line = sdscat(line,name);
line = sdscat(line," ");
for (unsigned int j = 0; j < server.trustedIPCount; j++) {
struct in_addr addr = { 0 };
addr.s_addr = server.trustedIPList[j];
line = sdscat(line, inet_ntoa(addr));
if (j != (server.trustedIPCount - 1)) {
line = sdscat(line," ");
}
}
}
rewriteConfigRewriteLine(state, name, line, 1);
}

static int setConfigBindOption(standardConfig *config, sds *argv, int argc, const char **err) {
UNUSED(config);
int j;
Expand Down Expand Up @@ -3238,6 +3317,7 @@ standardConfig static_configs[] = {
createSpecialConfig("bind", NULL, MODIFIABLE_CONFIG | MULTI_ARG_CONFIG, setConfigBindOption, getConfigBindOption, rewriteConfigBindOption, applyBind),
createSpecialConfig("replicaof", "slaveof", IMMUTABLE_CONFIG | MULTI_ARG_CONFIG, setConfigReplicaOfOption, getConfigReplicaOfOption, rewriteConfigReplicaOfOption, NULL),
createSpecialConfig("latency-tracking-info-percentiles", NULL, MODIFIABLE_CONFIG | MULTI_ARG_CONFIG, setConfigLatencyTrackingInfoPercentilesOutputOption, getConfigLatencyTrackingInfoPercentilesOutputOption, rewriteConfigLatencyTrackingInfoPercentilesOutputOption, NULL),
createSpecialConfig("trusted-addresses", NULL, MODIFIABLE_CONFIG | MULTI_ARG_CONFIG, setConfigTrustedAddresses, getConfigTrustedAddresses, rewriteConfigTrustedAdresses, NULL),

/* NULL Terminator, this is dropped when we convert to the runtime array. */
{NULL}
Expand Down
36 changes: 34 additions & 2 deletions src/networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ void clientAcceptHandler(connection *conn) {
moduleFireServerEvent(VALKEYMODULE_EVENT_CLIENT_CHANGE, VALKEYMODULE_SUBEVENT_CLIENT_CHANGE_CONNECTED, c);
}

void acceptCommonHandler(connection *conn, int flags, char *ip) {
void acceptCommonHandler(connection *conn, int flags, char *ip, const struct sockaddr_storage *sa) {
client *c;
UNUSED(ip);

Expand All @@ -1342,13 +1342,45 @@ void acceptCommonHandler(connection *conn, int flags, char *ip) {
connClose(conn);
return;
}
in_addr_t ip_addr = 0;
if (sa != NULL && sa->ss_family == AF_INET)
ip_addr = ((struct sockaddr_in *)sa)->sin_addr.s_addr;

if (server.trustedIPCount && ip_addr) {
if (/*!isTrustedNetwork(c) &&*/ !checkTrustedIP(ip_addr)) {
serverLog(LL_VERBOSE, "Access denied as connection is not from trusted source");

char *err = "-ERR client's IP is not found in trusted-addresses list, access denied\r\n";

/* That's a best effort error message, don't check write errors */
if (connWrite(conn,err,strlen(err)) == -1) {
/* Nothing to do, Just to avoid the warning... */
}
server.stat_rejected_conn++;
connClose(conn);
return;
}
} else if (server.trustedIPCount) {
serverLog(LL_VERBOSE, "Source address is not valid, client id");

char *err = "-ERR unable to retrieve valid IP address\r\n";

/* That's a best effort error message, don't check write errors */
if (connWrite(conn,err,strlen(err)) == -1) {
/* Nothing to do, Just to avoid the warning... */
}
server.stat_rejected_conn++;
connClose(conn);
return;
}

/* Limit the number of connections we take at the same time.
*
* Admission control will happen before a client is created and connAccept()
* called, because we don't want to even start transport-level negotiation
* if rejected. */
if (listLength(server.clients) + getClusterConnectionsCount() >= server.maxclients) {
if (!checkTrustedIP(ip_addr) &&
(listLength(server.clients) + getClusterConnectionsCount() >= server.maxclients)) {
char *err;
if (server.cluster_enabled)
err = "-ERR max number of clients + cluster "
Expand Down
21 changes: 21 additions & 0 deletions src/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,8 @@ void initServerConfig(void) {
server.bindaddr_count = CONFIG_DEFAULT_BINDADDR_COUNT;
for (j = 0; j < CONFIG_DEFAULT_BINDADDR_COUNT; j++) server.bindaddr[j] = zstrdup(default_bindaddr[j]);
memset(server.listeners, 0x00, sizeof(server.listeners));
server.host_machine_ip = 0;
server.host_machine_netmask = 0;
server.active_expire_enabled = 1;
server.lazy_expire_disabled = 0;
server.skip_checksum_validation = 0;
Expand Down Expand Up @@ -2562,6 +2564,25 @@ void initServer(void) {
server.reply_buffer_resizing_enabled = 1;
server.client_mem_usage_buckets = NULL;
resetReplicationBuffer();
char *default_bindaddr[CONFIG_DEFAULT_BINDADDR_COUNT] = CONFIG_DEFAULT_BINDADDR;
if (server.bindaddr_count > 0 && strcmp(server.bindaddr[0],default_bindaddr[0])) {
serverLog(LL_WARNING, "bind adrs.%d : %s",server.bindaddr_count, server.bindaddr[0]);
server.host_machine_ip = inet_addr(server.bindaddr[0]);
} else {
serverLog(LL_WARNING, "local loopback.");
server.host_machine_ip = inet_addr("127.0.0.1");
}

if (server.host_machine_ip <= 0) {
serverLog(LL_WARNING, "Can not get host machine network ip, exiting.");
exit(1);
}

server.host_machine_netmask = getIPv4Netmask(server.host_machine_ip);
if (server.host_machine_netmask <= 0) {
serverLog(LL_WARNING, "Can not get host machine network netmask, exiting.");
exit(1);
}

/* Make sure the locale is set on startup based on the config file. */
if (setlocale(LC_COLLATE, server.locale_collate) == NULL) {
Expand Down
14 changes: 13 additions & 1 deletion src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,10 @@ struct valkeyServer {
list *replicas, *monitors; /* List of replicas and MONITORs */
client *current_client; /* The client that triggered the command execution (External or AOF). */
client *executing_client; /* The client executing the current command (possibly script or module). */
in_addr_t host_machine_ip; /*Listening ip for host machine network*/
in_addr_t host_machine_netmask; /*Netmask for host_machine_ip*/
in_addr_t *trustedIPList;
unsigned int trustedIPCount;

#ifdef LOG_REQ_RES
char *req_res_logfile; /* Path of log file for logging all requests and their replies. If NULL, no logging will be
Expand Down Expand Up @@ -2620,6 +2624,7 @@ void dictVanillaFree(dict *d, void *val);
(1ULL << 0) /* Indicating that we should not update \
error stats after sending error reply */
/* networking.c -- Networking and Client related operations */
struct sockaddr_storage;
client *createClient(connection *conn);
void freeClient(client *c);
void freeClientAsync(client *c);
Expand All @@ -2637,7 +2642,7 @@ void setDeferredSetLen(client *c, void *node, long length);
void setDeferredAttributeLen(client *c, void *node, long length);
void setDeferredPushLen(client *c, void *node, long length);
int processInputBuffer(client *c);
void acceptCommonHandler(connection *conn, int flags, char *ip);
void acceptCommonHandler(connection *conn, int flags, char *ip, const struct sockaddr_storage *sa);
void readQueryFromClient(connection *conn);
int prepareClientToWrite(client *c);
void addReplyNull(client *c);
Expand Down Expand Up @@ -2736,6 +2741,13 @@ int authRequired(client *c);
void putClientInPendingWriteQueue(client *c);
client *createCachedResponseClient(void);
void deleteCachedResponseClient(client *recording_client);
void setTrustedNetworkFlag(client *c, const struct sockaddr_storage *sa);
int isUnixNetwork(client *c);
int checkConnFromTrustedNetwork(client *c);
int isTrustedNetwork(client *c);
in_addr_t getIPv4Netmask(in_addr_t ip);
int checkTrustedIP(in_addr_t ip);
void valkeySortIP(in_addr_t *IPlist, unsigned int IPcount);

/* logreqres.c - logging of requests and responses */
void reqresReset(client *c, int free_buf);
Expand Down
5 changes: 3 additions & 2 deletions src/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,14 @@ static void connSocketAcceptHandler(aeEventLoop *el, int fd, void *privdata, int
UNUSED(privdata);

while (max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
struct sockaddr_storage sa;
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport, &sa);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK) serverLog(LL_WARNING, "Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE, "Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedSocket(cfd, NULL), 0, cip);
acceptCommonHandler(connCreateAcceptedSocket(cfd, NULL), 0, cip, &sa);
}
}

Expand Down
59 changes: 59 additions & 0 deletions src/trusted_network.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "server.h"
#include <sys/types.h>
#include <ifaddrs.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <sys/syscall.h>
#include <sys/socket.h>

int compareIP(const void *arg1, const void *arg2) {
in_addr_t ip1 = *(in_addr_t *)arg1;
in_addr_t ip2 = *(in_addr_t *)arg2;

if (ip1 == ip2)
return 0;
else if (ip1 < ip2)
return -1;
else
return 1;
}

void valkeySortIP(in_addr_t *IPlist, unsigned int IPcount) {
qsort(IPlist, IPcount, sizeof(IPlist[0]), compareIP);
}

int checkTrustedIP(in_addr_t ip) {
return bsearch(&ip, server.trustedIPList, server.trustedIPCount,
sizeof(server.trustedIPList[0]), compareIP) != NULL ? 1 : 0;
}

int isUnixNetwork(client *c) {
return c->flags & CLIENT_UNIX_SOCKET;
}


in_addr_t getIPv4Netmask(in_addr_t ip) {
struct ifaddrs *addrs = NULL;
in_addr_t netmask = 0;

if (getifaddrs(&addrs) == -1)
return 0;

for (struct ifaddrs *addr = addrs; addr != NULL; addr = addr->ifa_next) {
if (addr->ifa_addr == NULL || addr->ifa_netmask == NULL)
continue;

if (addr->ifa_addr->sa_family != AF_INET || addr->ifa_netmask->sa_family != AF_INET)
continue;

struct sockaddr_in *in_addr = (struct sockaddr_in *)addr->ifa_addr;
if (in_addr->sin_addr.s_addr == ip) {
struct sockaddr_in *mask = (struct sockaddr_in *)addr->ifa_netmask;
netmask = mask->sin_addr.s_addr;
break;
}
}

freeifaddrs(addrs);
return netmask;
}
2 changes: 1 addition & 1 deletion src/unix.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ static void connUnixAcceptHandler(aeEventLoop *el, int fd, void *privdata, int m
return;
}
serverLog(LL_VERBOSE, "Accepted connection to %s", server.unixsocket);
acceptCommonHandler(connCreateAcceptedUnix(cfd, NULL), CLIENT_UNIX_SOCKET, NULL);
acceptCommonHandler(connCreateAcceptedUnix(cfd, NULL), CLIENT_UNIX_SOCKET, NULL, NULL);
}
}

Expand Down
Loading

0 comments on commit 914a7c7

Please sign in to comment.