diff --git a/include/MySQL_Data_Stream.h b/include/MySQL_Data_Stream.h index 1ce5fed802..53eafe8ec3 100644 --- a/include/MySQL_Data_Stream.h +++ b/include/MySQL_Data_Stream.h @@ -5,6 +5,7 @@ #include "cpp.h" #include "MySQL_Protocol.h" +#include "proxy_protocol_info.h" #ifndef uchar typedef unsigned char uchar; @@ -140,6 +141,7 @@ class MySQL_Data_Stream char *addr; int port; } proxy_addr; + ProxyProtocolInfo * PROXY_info; unsigned int connect_tries; int query_retries_on_failure; diff --git a/include/MySQL_Thread.h b/include/MySQL_Thread.h index ddd663056d..b238915b48 100644 --- a/include/MySQL_Thread.h +++ b/include/MySQL_Thread.h @@ -460,6 +460,7 @@ class MySQL_Threads_Handler char *server_version; char *keep_multiplexing_variables; char *default_authentication_plugin; + char *proxy_protocol_networks; //unsigned int default_charset; // removed in 2.0.13 . Obsoleted previously using MySQL_Variables instead int handle_unknown_charset; int default_authentication_plugin_int; diff --git a/include/proxy_protocol_info.h b/include/proxy_protocol_info.h new file mode 100644 index 0000000000..9f35b2915b --- /dev/null +++ b/include/proxy_protocol_info.h @@ -0,0 +1,51 @@ +#ifndef PROXY_PROTOCOL_INFO_H +#define PROXY_PROTOCOL_INFO_H + +#include +#include +#include +#include + + +class ProxyProtocolInfo { +public: + char source_address[INET6_ADDRSTRLEN+1]; + char destination_address[INET6_ADDRSTRLEN+1]; + char proxy_address[INET6_ADDRSTRLEN+1]; + uint16_t source_port; + uint16_t destination_port; + uint16_t proxy_port; + + // Constructor (initializes to zeros) + ProxyProtocolInfo() { + memset(this, 0, sizeof(ProxyProtocolInfo)); + } + + // Copy constructor + ProxyProtocolInfo(const ProxyProtocolInfo& other) { + memcpy(this, &other, sizeof(ProxyProtocolInfo)); + } + + // Function to parse the PROXY protocol header (declared) + bool parseProxyProtocolHeader(const char* packet, size_t packet_length); + + bool is_in_network(const struct sockaddr* client_addr, const std::string& subnet_mask); + bool is_client_in_any_subnet(const struct sockaddr* client_addr, const char* subnet_list); + + // Copy method + ProxyProtocolInfo& copy(const ProxyProtocolInfo& other) { + if (this != &other) { + memcpy(this, &other, sizeof(ProxyProtocolInfo)); + } + return *this; + } +#ifdef DEBUG + sockaddr_in create_ipv4_addr(const std::string& ip); + sockaddr_in6 create_ipv6_addr(const std::string& ip); + void run_tests(); +#endif // DEBUG + bool is_valid_subnet_list(const char* subnet_list); + bool is_valid_subnet(const char* subnet); +}; + +#endif // PROXY_PROTOCOL_INFO_H diff --git a/include/proxysql_structs.h b/include/proxysql_structs.h index deac187de0..0762e50cab 100644 --- a/include/proxysql_structs.h +++ b/include/proxysql_structs.h @@ -782,6 +782,7 @@ __thread char *mysql_thread___default_schema; __thread char *mysql_thread___server_version; __thread char *mysql_thread___keep_multiplexing_variables; __thread char *mysql_thread___default_authentication_plugin; +__thread char *mysql_thread___proxy_protocol_networks; __thread char *mysql_thread___init_connect; __thread char *mysql_thread___ldap_user_variable; __thread char *mysql_thread___default_session_track_gtids; @@ -954,6 +955,7 @@ extern __thread char *mysql_thread___default_schema; extern __thread char *mysql_thread___server_version; extern __thread char *mysql_thread___keep_multiplexing_variables; extern __thread char *mysql_thread___default_authentication_plugin; +extern __thread char *mysql_thread___proxy_protocol_networks; extern __thread char *mysql_thread___init_connect; extern __thread char *mysql_thread___ldap_user_variable; extern __thread char *mysql_thread___default_session_track_gtids; diff --git a/lib/Makefile b/lib/Makefile index 325b5bd80c..88aeff654e 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -130,6 +130,7 @@ _OBJ_CXX := ProxySQL_GloVars.oo network.oo debug.oo configfile.oo Query_Cache.oo QP_rule_text.oo QP_query_digest_stats.oo \ GTID_Server_Data.oo MyHGC.oo MySrvConnList.oo MySrvList.oo MySrvC.oo \ MySQL_encode.oo MySQL_ResultSet.oo \ + proxy_protocol_info.oo \ proxysql_find_charset.oo ProxySQL_Poll.oo OBJ_CXX := $(patsubst %,$(ODIR)/%,$(_OBJ_CXX)) HEADERS := ../include/*.h ../include/*.hpp diff --git a/lib/MySQL_Thread.cpp b/lib/MySQL_Thread.cpp index 358a99cb1f..eaf60184b2 100644 --- a/lib/MySQL_Thread.cpp +++ b/lib/MySQL_Thread.cpp @@ -499,6 +499,7 @@ static char * mysql_thread_variables_names[]= { (char *)"data_packets_history_size", (char *)"handle_warnings", (char *)"evaluate_replication_lag_on_servers_load", + (char *)"proxy_protocol_networks", NULL }; @@ -1119,6 +1120,7 @@ MySQL_Threads_Handler::MySQL_Threads_Handler() { variables.ssl_p2s_crl=NULL; variables.ssl_p2s_crlpath=NULL; variables.keep_multiplexing_variables=strdup((char *)"tx_isolation,transaction_isolation,version"); + variables.proxy_protocol_networks = strdup((char *)""); variables.default_authentication_plugin=strdup((char *)"mysql_native_password"); variables.default_authentication_plugin_int = 0; // mysql_native_password #ifdef DEBUG @@ -1350,6 +1352,7 @@ char * MySQL_Threads_Handler::get_variable_string(char *name) { if (!strcmp(name,"interfaces")) return strdup(variables.interfaces); if (!strcmp(name,"keep_multiplexing_variables")) return strdup(variables.keep_multiplexing_variables); if (!strcmp(name,"default_authentication_plugin")) return strdup(variables.default_authentication_plugin); + if (!strcmp(name,"proxy_protocol_networks")) return strdup(variables.proxy_protocol_networks); // LCOV_EXCL_START proxy_error("Not existing variable: %s\n", name); assert(0); return NULL; @@ -1505,6 +1508,7 @@ char * MySQL_Threads_Handler::get_variable(char *name) { // this is the public f if (!strcasecmp(name,"default_schema")) return strdup(variables.default_schema); if (!strcasecmp(name,"keep_multiplexing_variables")) return strdup(variables.keep_multiplexing_variables); if (!strcasecmp(name,"default_authentication_plugin")) return strdup(variables.default_authentication_plugin); + if (!strcasecmp(name,"proxy_protocol_networks")) return strdup(variables.proxy_protocol_networks); if (!strcasecmp(name,"interfaces")) return strdup(variables.interfaces); if (!strcasecmp(name,"server_capabilities")) { // FIXME : make it human readable @@ -1878,6 +1882,28 @@ bool MySQL_Threads_Handler::set_variable(char *name, const char *value) { // thi return false; } } + if (!strcasecmp(name,"proxy_protocol_networks")) { + bool ret = false; + if (vallen == 0) { + // accept empty string + ret = true; + } else if ( (vallen == 1) && strcmp(value,"*")==0) { + // accept `*` + ret = true; + } else { + ProxyProtocolInfo ppi; + if (ppi.is_valid_subnet_list(value) == true) { + ret = true; + } + } + if (ret == true) { + free(variables.proxy_protocol_networks); + variables.proxy_protocol_networks=strdup(value); + return true; + } else { + return true; + } + } // SSL proxy to server variables if (!strcasecmp(name,"ssl_p2s_ca")) { if (variables.ssl_p2s_ca) free(variables.ssl_p2s_ca); @@ -2707,6 +2733,7 @@ MySQL_Threads_Handler::~MySQL_Threads_Handler() { if (variables.server_version) free(variables.server_version); if (variables.keep_multiplexing_variables) free(variables.keep_multiplexing_variables); if (variables.default_authentication_plugin) free(variables.default_authentication_plugin); + if (variables.proxy_protocol_networks) free(variables.proxy_protocol_networks); if (variables.firewall_whitelist_errormsg) free(variables.firewall_whitelist_errormsg); if (variables.init_connect) free(variables.init_connect); if (variables.ldap_user_variable) free(variables.ldap_user_variable); @@ -2838,6 +2865,7 @@ MySQL_Thread::~MySQL_Thread() { if (mysql_thread___server_version) { free(mysql_thread___server_version); mysql_thread___server_version=NULL; } if (mysql_thread___keep_multiplexing_variables) { free(mysql_thread___keep_multiplexing_variables); mysql_thread___keep_multiplexing_variables=NULL; } if (mysql_thread___default_authentication_plugin) { free(mysql_thread___default_authentication_plugin); mysql_thread___default_authentication_plugin=NULL; } + if (mysql_thread___proxy_protocol_networks) { free(mysql_thread___proxy_protocol_networks); mysql_thread___proxy_protocol_networks=NULL; } if (mysql_thread___firewall_whitelist_errormsg) { free(mysql_thread___firewall_whitelist_errormsg); mysql_thread___firewall_whitelist_errormsg=NULL; } if (mysql_thread___init_connect) { free(mysql_thread___init_connect); mysql_thread___init_connect=NULL; } if (mysql_thread___ldap_user_variable) { free(mysql_thread___ldap_user_variable); mysql_thread___ldap_user_variable=NULL; } @@ -4381,6 +4409,7 @@ void MySQL_Thread::refresh_variables() { GloMyLogger->audit_set_base_filename(); // both filename and filesize are set here REFRESH_VARIABLE_CHAR(default_schema); REFRESH_VARIABLE_CHAR(keep_multiplexing_variables); + REFRESH_VARIABLE_CHAR(proxy_protocol_networks); REFRESH_VARIABLE_CHAR(default_authentication_plugin); mysql_thread___default_authentication_plugin_int = GloMTH->variables.default_authentication_plugin_int; mysql_thread___server_capabilities=GloMTH->get_variable_uint16((char *)"server_capabilities"); @@ -5087,30 +5116,40 @@ SQLite3_result * MySQL_Threads_Handler::SQL3_Processlist() { } } - if (sess->mirror==false) { - switch (sess->client_myds->client_addr->sa_family) { - case AF_INET: { - struct sockaddr_in *ipv4 = (struct sockaddr_in *)sess->client_myds->client_addr; - inet_ntop(sess->client_myds->client_addr->sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN); - pta[4] = strdup(buf); - sprintf(port, "%d", ntohs(ipv4->sin_port)); - pta[5] = strdup(port); - break; - } - case AF_INET6: { - struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)sess->client_myds->client_addr; - inet_ntop(sess->client_myds->client_addr->sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN); - pta[4] = strdup(buf); - sprintf(port, "%d", ntohs(ipv6->sin6_port)); - pta[5] = strdup(port); - break; - } - default: - pta[4] = strdup("localhost"); - pta[5] = NULL; - break; - } - } else { + if (sess->mirror==false) { + switch (sess->client_myds->client_addr->sa_family) { + case AF_INET: + if (sess->client_myds->addr.addr != NULL) { + pta[4] = strdup(sess->client_myds->addr.addr); + sprintf(port, "%d", sess->client_myds->addr.port); + pta[5] = strdup(port); + } else { + struct sockaddr_in *ipv4 = (struct sockaddr_in *)sess->client_myds->client_addr; + inet_ntop(sess->client_myds->client_addr->sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN); + pta[4] = strdup(buf); + sprintf(port, "%d", ntohs(ipv4->sin_port)); + pta[5] = strdup(port); + } + break; + case AF_INET6: + if (sess->client_myds->addr.addr != NULL) { + pta[4] = strdup(sess->client_myds->addr.addr); + sprintf(port, "%d", sess->client_myds->addr.port); + pta[5] = strdup(port); + } else { + struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)sess->client_myds->client_addr; + inet_ntop(sess->client_myds->client_addr->sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN); + pta[4] = strdup(buf); + sprintf(port, "%d", ntohs(ipv6->sin6_port)); + pta[5] = strdup(port); + } + break; + default: + pta[4] = strdup("localhost"); + pta[5] = NULL; + break; + } + } else { pta[4] = strdup("mirror_internal"); pta[5] = NULL; } @@ -5126,28 +5165,28 @@ SQLite3_result * MySQL_Threads_Handler::SQL3_Processlist() { int rc; rc=getsockname(mc->fd, &addr, &addr_len); if (rc==0) { - switch (addr.sa_family) { - case AF_INET: { - struct sockaddr_in *ipv4 = (struct sockaddr_in *)&addr; - inet_ntop(addr.sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN); - pta[7] = strdup(buf); - sprintf(port, "%d", ntohs(ipv4->sin_port)); - pta[8] = strdup(port); - break; - } - case AF_INET6: { - struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)&addr; - inet_ntop(addr.sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN); - pta[7] = strdup(buf); - sprintf(port, "%d", ntohs(ipv6->sin6_port)); - pta[8] = strdup(port); - break; - } - default: - pta[7] = strdup("localhost"); - pta[8] = NULL; - break; - } + switch (addr.sa_family) { + case AF_INET: { + struct sockaddr_in *ipv4 = (struct sockaddr_in *)&addr; + inet_ntop(addr.sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN); + pta[7] = strdup(buf); + sprintf(port, "%d", ntohs(ipv4->sin_port)); + pta[8] = strdup(port); + break; + } + case AF_INET6: { + struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)&addr; + inet_ntop(addr.sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN); + pta[7] = strdup(buf); + sprintf(port, "%d", ntohs(ipv6->sin6_port)); + pta[8] = strdup(port); + break; + } + default: + pta[7] = strdup("localhost"); + pta[8] = NULL; + break; + } } else { pta[7]=NULL; pta[8]=NULL; diff --git a/lib/mysql_data_stream.cpp b/lib/mysql_data_stream.cpp index a45e77dc93..a43960af4a 100644 --- a/lib/mysql_data_stream.cpp +++ b/lib/mysql_data_stream.cpp @@ -307,6 +307,8 @@ MySQL_Data_Stream::MySQL_Data_Stream() { proxy_addr.addr=NULL; proxy_addr.port=0; + PROXY_info = NULL; + sess=NULL; mysql_real_query.pkt.ptr=NULL; mysql_real_query.pkt.size=0; @@ -380,6 +382,10 @@ MySQL_Data_Stream::~MySQL_Data_Stream() { free(proxy_addr.addr); proxy_addr.addr=NULL; } + if (PROXY_info) { + delete PROXY_info; + PROXY_info = NULL; + } free_mysql_real_query(); @@ -1055,6 +1061,92 @@ int MySQL_Data_Stream::buffer2array() { } else { if ((queueIN.pkt.size==0) && queue_data(queueIN)>=sizeof(mysql_hdr)) { + // check if this is a PROXY protocol packet + if ( + pkts_recv==0 && // checks if no packets have been received yet + queueIN.tail == 0 && // checks if the input queue (`queueIN`) was never rotated . This check is redundant + queueIN.head > 7 && // ensures that there are at least 8 bytes in the input buffer (`queueIN.buffer`) + // This is because the PROXY protocol signature (`PROXY`) is 5 bytes long, and we need at least 3 more bytes to check for the `\r\n` delimiter. + strncmp((char *)queueIN.buffer,"PROXY ",6) == 0 // checks if the first 6 bytes of the buffer match the "PROXY " string, indicating a potential PROXY protocol packet + ) { + bool found_delimiter = false; + size_t b = 0; + const char *ptr = (char *)queueIN.buffer; + // This loop iterates through the buffer, starting from the 8th byte (index 7) until the end of the buffer (index `queueIN.head - 1`). + // The loop continues as long as the delimiter hasn't been found (`found_delimiter == false`) + // the loop looks for \r\n , the delimiter of the PROXY packet + for (size_t i = 7; found_delimiter == false && i < queueIN.head - 1; i++) { + if ( + ptr[i] == '\r' + && + ptr[i+1] == '\n' + ) { + found_delimiter = true; + b = i+2; + } + } + if (found_delimiter) { +/* + // we could return a packet, but it is actually better to handle it here + queueIN.pkt.size = b; + queueIN.pkt.ptr=l_alloc(queueIN.pkt.size); + memcpy(queueIN.pkt.ptr, queueIN.buffer, b); + PSarrayIN->add(queueIN.pkt.ptr,queueIN.pkt.size); + add_to_data_packet_history(data_packets_history_IN,queueIN.pkt.ptr,queueIN.pkt.size); +*/ + // we move forward the internal pointer. + // note that parseProxyProtocolHeader() will read from the beginning of the buffer + queue_r(queueIN, b); + + bool accept_proxy = false; // by default, we do not accept a PROXY header + const char * proxy_protocol_networks = mysql_thread___proxy_protocol_networks; + + ProxyProtocolInfo ppi; + if (strcmp(proxy_protocol_networks,"*") == 0) { // all networks are accepted + accept_proxy = true; + } else { + if (client_addr) { + if (ppi.is_client_in_any_subnet(client_addr, proxy_protocol_networks) == true) { + accept_proxy = true; + } + } + } + if (accept_proxy == true) { + if (ppi.parseProxyProtocolHeader((const char *)queueIN.buffer, b)) { + PROXY_info = new ProxyProtocolInfo(ppi); + // we take a copy of old address/port + if (addr.addr) { + strncpy(PROXY_info->proxy_address, addr.addr, INET6_ADDRSTRLEN); + free(addr.addr); + } + PROXY_info->proxy_port = addr.port; + // we override old address/port + addr.addr = strdup(PROXY_info->source_address); + addr.port = PROXY_info->source_port; + } else { + if (addr.addr) { + proxy_warning("Unable to parse PROXY header from IP %s . Skipping PROXY header\n", addr.addr); + } + } + } else { // the PROXY header was not accepted + if (addr.addr) { + proxy_warning("Skipping PROXY header from IP %s because not matching mysql-proxy_protocol_networks. Skipping PROXY header\n", addr.addr); + } + } + + + pkts_recv++; + queueIN.pkt.size=0; + queueIN.pkt.ptr=NULL; + return b; + } else { + // set the connection unhealthy , this will cause the session to be destroyed + if (sess) { + sess->set_unhealthy(); + } + } + return 0; // we always return + } proxy_debug(PROXY_DEBUG_PKT_ARRAY, 5, "Session=%p . Reading the header of a new packet\n", sess); memcpy(&queueIN.hdr,queue_r_ptr(queueIN),sizeof(mysql_hdr)); pkt_sid=queueIN.hdr.pkt_id; @@ -1550,6 +1642,14 @@ void MySQL_Data_Stream::get_client_myds_info_json(json& j) { jc1["client_addr"]["port"] = addr.port; jc1["proxy_addr"]["address"] = ( proxy_addr.addr ? proxy_addr.addr : "" ); jc1["proxy_addr"]["port"] = proxy_addr.port; + if (PROXY_info != NULL) { + jc1["PROXY_V1"]["source_address"] = PROXY_info->source_address; + jc1["PROXY_V1"]["destination_address"] = PROXY_info->destination_address; + jc1["PROXY_V1"]["proxy_address"] = PROXY_info->proxy_address; + jc1["PROXY_V1"]["source_port"] = PROXY_info->source_port; + jc1["PROXY_V1"]["destination_port"] = PROXY_info->destination_port; + jc1["PROXY_V1"]["proxy_port"] = PROXY_info->proxy_port; + } jc1["encrypted"] = encrypted; if (encrypted) { const SSL_CIPHER *cipher = SSL_get_current_cipher(ssl); diff --git a/lib/proxy_protocol_info.cpp b/lib/proxy_protocol_info.cpp new file mode 100644 index 0000000000..522463f643 --- /dev/null +++ b/lib/proxy_protocol_info.cpp @@ -0,0 +1,382 @@ +#include "proxy_protocol_info.h" +#include +#include +#include +#include + +static bool DEBUG_ProxyProtocolInfo = false; + +// Function to parse the PROXY protocol header +bool ProxyProtocolInfo::parseProxyProtocolHeader(const char* packet, size_t packet_length) { + // Check for minimum header length (including CRLF) + if (packet_length < 15) { + return false; // Not a valid PROXY protocol header + } + + // Create a temporary buffer on the stack + char temp_buffer[packet_length + 1]; + + // Copy the packet data + memcpy(temp_buffer, packet, packet_length); + temp_buffer[packet_length] = '\0'; // Null-terminate the buffer + + + // Verify the PROXY protocol signature + if (memcmp(temp_buffer, "PROXY", 5) != 0) { + return false; // Not a valid PROXY protocol header + } + + // Check for the space after "PROXY" + if (temp_buffer[5] != ' ') { + return false; // Invalid header format + } + + // Check for the protocol type + if (memcmp(temp_buffer + 6, "TCP4", 4) == 0 || + memcmp(temp_buffer + 6, "TCP6", 4) == 0 || + memcmp(temp_buffer + 6, "UNKNOWN", 7) == 0) { + + // Parse the header using sscanf + int result = sscanf(temp_buffer, "PROXY %*s %s %s %hu %hu\r\n", + source_address, destination_address, + &source_port, &destination_port); + + // Check if sscanf successfully parsed all fields + if (result == 4) { + return true; // Successful parsing + } else { + // Handle partial parsing or invalid format + return false; // Indicate an error + } + } + + return false; // Invalid header format +} + +/** + * Checks if a client address is within a specified subnet. + * + * @param client_addr Pointer to the client's sockaddr structure (either sockaddr_in or sockaddr_in6). + * @param subnet_mask The subnet in CIDR notation (e.g., "192.168.1.0/24" for IPv4 or "2001:db8::/32" for IPv6). + * @return True if the client address is within the specified subnet, otherwise false. + */ +bool ProxyProtocolInfo::is_in_network(const struct sockaddr* client_addr, const std::string& subnet_mask) { + // Determine address family (IPv4 or IPv6) + int family = client_addr->sa_family; + + // Parse the subnet and mask + union { + struct in_addr v4; + struct in6_addr v6; + } subnet_addr; + + uint8_t mask = 0; + char addr_str[INET6_ADDRSTRLEN]; + + if (family == AF_INET) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Parsing IPv4 subnet mask" << std::endl; + // Parse the IPv4 subnet mask using sscanf + if (sscanf(subnet_mask.c_str(), "%[^/]/%hhu", addr_str, &mask) != 2) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Invalid subnet/mask format" << std::endl; + return false; // Invalid subnet/mask format + } + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Subnet: " << addr_str << ", Mask: " << (int)mask << std::endl; + // Convert the parsed subnet address to binary format + if (inet_pton(AF_INET, addr_str, &subnet_addr.v4) != 1) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Invalid IPv4 address" << std::endl; + return false; // Invalid IPv4 address + } + } else if (family == AF_INET6) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Parsing IPv6 subnet mask" << std::endl; + // Parse the IPv6 subnet mask using sscanf + if (sscanf(subnet_mask.c_str(), "%[^/]/%hhu", addr_str, &mask) != 2) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Invalid subnet/mask format" << std::endl; + return false; // Invalid subnet/mask format + } + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Subnet: " << addr_str << ", Mask: " << (int)mask << std::endl; + // Convert the parsed subnet address to binary format + if (inet_pton(AF_INET6, addr_str, &subnet_addr.v6) != 1) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Invalid IPv6 address" << std::endl; + return false; // Invalid IPv6 address + } + } else { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Unsupported address family" << std::endl; + return false; // Unsupported address family + } + + uint8_t network_addr[16] = {0}; + if (family == AF_INET) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Calculating network address for IPv4" << std::endl; + // Calculate the network address for IPv4 + uint32_t subnet = ntohl(subnet_addr.v4.s_addr) & (0xFFFFFFFF << (32 - mask)); + subnet = htonl(subnet); + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Subnet address (masked): " << inet_ntoa(*(struct in_addr*)&subnet) << std::endl; + // Copy the masked subnet address into the network_addr array + memcpy(network_addr, &subnet, sizeof(subnet)); + } else if (family == AF_INET6) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Calculating network address for IPv6" << std::endl; + // Calculate the network address for IPv6 + uint8_t* addr = subnet_addr.v6.s6_addr; + int bits_left = mask; + for (int i = 0; i < 16; ++i) { + if (bits_left >= 8) { + network_addr[i] = addr[i]; + bits_left -= 8; + } else if (bits_left > 0) { + network_addr[i] = addr[i] & (0xFF << (8 - bits_left)); + bits_left = 0; + } else { + network_addr[i] = 0; + } + } + if (DEBUG_ProxyProtocolInfo==true) { + char network_addr_str[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, network_addr, network_addr_str, INET6_ADDRSTRLEN); + std::cout << "Subnet address (masked): " << network_addr_str << std::endl; + } + } + + uint8_t client_addr_int[16] = {0}; + if (family == AF_INET) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Extracting client address for IPv4" << std::endl; + // Extract the client address for IPv4 + uint32_t client = ntohl(((struct sockaddr_in*)client_addr)->sin_addr.s_addr); + client = htonl(client); + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Client address: " << inet_ntoa(*(struct in_addr*)&client) << std::endl; + // Copy the client address into the client_addr_int array + memcpy(client_addr_int, &client, sizeof(client)); + } else if (family == AF_INET6) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Extracting client address for IPv6" << std::endl; + // Copy the client address into the client_addr_int array + memcpy(client_addr_int, ((struct sockaddr_in6*)client_addr)->sin6_addr.s6_addr, 16); + if (DEBUG_ProxyProtocolInfo==true) { + char client_addr_str[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, client_addr_int, client_addr_str, INET6_ADDRSTRLEN); + std::cout << "Client address: " << client_addr_str << std::endl; + } + } + + // Calculate the number of bytes to compare based on the mask + int bytes_to_compare = mask / 8; + int remaining_bits = mask % 8; + + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Comparing full bytes covered by the mask" << std::endl; + // Compare the full bytes covered by the mask + if (memcmp(network_addr, client_addr_int, bytes_to_compare) != 0) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Address does not match in full byte comparison" << std::endl; + return false; + } + + if (remaining_bits > 0) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Comparing remaining bits" << std::endl; + // Compare the remaining bits covered by the mask + uint8_t mask_byte = 0xFF << (8 - remaining_bits); + if ((network_addr[bytes_to_compare] & mask_byte) != (client_addr_int[bytes_to_compare] & mask_byte)) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Address does not match in remaining bits comparison" << std::endl; + return false; // Addresses don't match in remaining bits comparison + } + } + + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Client address is within the subnet" << std::endl; + return true; // Client address is within the subnet +} + +bool ProxyProtocolInfo::is_client_in_any_subnet(const struct sockaddr* client_addr, const char* subnet_list) { + // Create a copy of the subnet list to avoid modifying the original string + char* subnet_list_copy = new char[strlen(subnet_list) + 1]; + strcpy(subnet_list_copy, subnet_list); + + char* token = strtok(subnet_list_copy, ","); // Get the first subnet + while (token != NULL) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Checking subnet: " << token << std::endl; + if (is_in_network(client_addr, token)) { + if (DEBUG_ProxyProtocolInfo==true) + std::cout << "Client is in subnet: " << token << std::endl; + delete[] subnet_list_copy; // Deallocate the copy + return true; // Client is in at least one subnet + } + token = strtok(NULL, ","); // Get the next subnet + } + delete[] subnet_list_copy; // Deallocate the copy + return false; // Client is not in any of the subnets +} + +#ifdef DEBUG + +// Helper function to create an IPv4 sockaddr structure +sockaddr_in ProxyProtocolInfo::create_ipv4_addr(const std::string& ip) { + sockaddr_in addr; + addr.sin_family = AF_INET; + inet_pton(AF_INET, ip.c_str(), &addr.sin_addr); + return addr; +} + +// Helper function to create an IPv6 sockaddr structure +sockaddr_in6 ProxyProtocolInfo::create_ipv6_addr(const std::string& ip) { + sockaddr_in6 addr; + addr.sin6_family = AF_INET6; + inet_pton(AF_INET6, ip.c_str(), &addr.sin6_addr); + return addr; +} + +// Test cases for the is_in_network function +void ProxyProtocolInfo::run_tests() { + // IPv4 Tests + { + sockaddr_in client_addr = create_ipv4_addr("192.168.1.10"); + assert(is_in_network((sockaddr*)&client_addr, "192.168.1.0/24") == true); + assert(is_in_network((sockaddr*)&client_addr, "192.168.1.0/25") == true); + assert(is_in_network((sockaddr*)&client_addr, "192.168.1.0/26") == true); + assert(is_in_network((sockaddr*)&client_addr, "192.168.2.0/24") == false); + assert(is_in_network((sockaddr*)&client_addr, "192.168.0.0/16") == true); + assert(is_in_network((sockaddr*)&client_addr, "192.168.1.10/32") == true); + assert(is_in_network((sockaddr*)&client_addr, "192.168.1.11/32") == false); + } + + // IPv6 Tests + { + sockaddr_in6 client_addr = create_ipv6_addr("2001:db8::1"); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/32") == true); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/48") == true); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/64") == true); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:1::/64") == false); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::1/128") == true); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::2/128") == false); + } + { + sockaddr_in6 client_addr = create_ipv6_addr("2001:db8:0:1::1"); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:1::/64") == true); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::/32") == true); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:2::/64") == false); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8::1/128") == false); + assert(is_in_network((sockaddr*)&client_addr, "2001:db8:0:1::1/128") == true); + } + { + struct sockaddr_in client_addr = create_ipv4_addr("172.16.14.1"); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.16.0.0/16,192.168.1.0/24") == true); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.17.0.0/16,192.168.1.0/24") == false); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,172.16.0.0/16,192.168.1.0/24") == true); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,172.17.0.0/16,192.168.1.0/24") == false); + } + { + sockaddr_in6 client_addr = create_ipv6_addr("2001:db8:0:1::1"); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,2001:db8:0:2::/64") == true); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:2::/64,2001:db8:0:1::/64") == true); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:1::/64,172.16.0.0/16") == true); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.16.0.0/16,2001:db8:0:1::/64") == true); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "2001:db8:0:2::/64,172.16.0.0/16") == false); + assert(is_client_in_any_subnet((sockaddr*)&client_addr, "172.16.0.0/16,2001:db8:0:2::/64") == false); + } + { + const char* subnet_list1 = "192.168.1.0/24,10.0.0.0/8,2001:0:200::/32"; + const char* subnet_list2 = "192.168.1.0/24,10.0.0.0/not_a_mask,2001:0:200::/32"; + const char* subnet_list3 = "192.168.1.0/24,invalid_ipv4,2001:0:200::/32"; + const char* subnet_list4 = ""; + + assert(is_valid_subnet_list(subnet_list1) == true); + assert(is_valid_subnet_list(subnet_list2) == false); + assert(is_valid_subnet_list(subnet_list3) == false); + assert(is_valid_subnet_list(subnet_list4) == false); + } +} + +#endif // DEBUG + +bool ProxyProtocolInfo::is_valid_subnet_list(const char* subnet_list) { + // Check if the string is empty + if (subnet_list == nullptr || *subnet_list == '\0') { + return false; // Empty string is not a valid subnet list + } + + // Create a copy of the string to avoid modifying the original + char* subnet_list_copy = new char[strlen(subnet_list) + 1]; + strcpy(subnet_list_copy, subnet_list); + + // Tokenize the string using ',' as the delimiter + char* token = strtok(subnet_list_copy, ","); + while (token != NULL) { + // Check if the token is a valid subnet + if (!is_valid_subnet(token)) { + delete[] subnet_list_copy; // Deallocate the copy + return false; // Invalid subnet found + } + token = strtok(NULL, ","); // Get the next token + } + + delete[] subnet_list_copy; // Deallocate the copy + return true; // All subnets are valid +} + + +// Helper function to verify a single subnet +bool ProxyProtocolInfo::is_valid_subnet(const char* subnet) { + // Check if the subnet is empty + if (subnet == NULL || *subnet == '\0') { + return false; // Empty subnet is not valid + } + + // Check if the subnet contains a '/' character (CIDR notation) + if (strchr(subnet, '/') == NULL) { + return false; // Missing '/' character in subnet + } + + // Check if the subnet is a valid IPv4 or IPv6 address + int family = AF_INET; // Default to IPv4 + if (strchr(subnet, ':') != NULL) { + family = AF_INET6; // IPv6 if a colon is found + } + + char addr_str[INET6_ADDRSTRLEN]; + uint8_t mask = 0; + + if (family == AF_INET) { + // Parse IPv4 subnet using sscanf + if (sscanf(subnet, "%[^/]/%hhu", addr_str, &mask) != 2) { + return false; // Invalid IPv4 subnet format + } + } else if (family == AF_INET6) { + // Parse IPv6 subnet using sscanf + if (sscanf(subnet, "%[^/]/%hhu", addr_str, &mask) != 2) { + return false; // Invalid IPv6 subnet format + } + } else { + return false; // Unsupported address family + } + + // Validate the mask value + if (mask < 0 || mask > 128) { + return false; // Invalid mask value + } + + // Check if the address is valid using inet_pton + union { + struct in_addr v4; + struct in6_addr v6; + } addr; // Create a union to hold both IPv4 and IPv6 addresses + if (inet_pton(family, addr_str, &addr) != 1) { + return false; // Invalid IP address + } + + return true; // Valid subnet +} diff --git a/src/main.cpp b/src/main.cpp index 62c47c33e4..382b42363a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -42,6 +42,11 @@ #include +#ifdef DEBUG +#include "proxy_protocol_info.h" +#endif // DEBUG + + /* extern "C" MySQL_LDAP_Authentication * create_MySQL_LDAP_Authentication_func() { return NULL; @@ -1969,6 +1974,17 @@ int main(int argc, const char * argv[]) { if (rc) { exit(EXIT_FAILURE); } } + +#ifdef DEBUG + { + // This run some ProxyProtocolInfo tests. + // It will assert() if any test fails + ProxyProtocolInfo ppi; + ppi.run_tests(); + } +#endif // DEBUG + + { MYSQL *my = mysql_init(NULL); mysql_close(my); diff --git a/test/tap/tests/test_PROXY_Protocol-t.cpp b/test/tap/tests/test_PROXY_Protocol-t.cpp new file mode 100644 index 0000000000..14919b0d49 --- /dev/null +++ b/test/tap/tests/test_PROXY_Protocol-t.cpp @@ -0,0 +1,147 @@ +/** + * @file test_PROXY_Protocol-t.cpp + * @brief This test tries the PROXY protocol + * @details The test performs authentication using the PROXY protocol , then + * verifies PROXYSQL INTERNAL SESSION + * @date 2024-08-07 + */ + +#include +#include +#include +#include "mysql.h" + +#include "tap.h" +#include "command_line.h" +#include "utils.h" +#include "json.hpp" + +#include // For std::pair + +using std::string; +using namespace nlohmann; + +void parse_result_json_column(MYSQL_RES *result, json& j) { + if(!result) return; + MYSQL_ROW row; + + while ((row = mysql_fetch_row(result))) { + j = json::parse(row[0]); + } +} + +int connect_and_run_query(CommandLine& cl, int tests, const char *hdr) { + int ret = 0; // number of success + MYSQL* proxysql_mysql = mysql_init(NULL); + + mysql_optionsv(proxysql_mysql, MARIADB_OPT_PROXY_HEADER, hdr, strlen(hdr)); + + if (!mysql_real_connect(proxysql_mysql, cl.host, cl.username, cl.password, NULL, cl.port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(proxysql_mysql)); + return ret; + } else { + ok(true, "Successfully connected"); + ret++; + } + MYSQL_QUERY(proxysql_mysql, "PROXYSQL INTERNAL SESSION"); + json j_status {}; + MYSQL_RES* int_session_res = mysql_store_result(proxysql_mysql); + parse_result_json_column(int_session_res, j_status); + mysql_free_result(int_session_res); + bool proxy_info_found = false; + + //diag("%s",j_status.dump(1).c_str()); + + json jv1 {}; + if (j_status.find("client") != j_status.end()) { + json& j = *j_status.find("client"); + if (j.find("PROXY_V1") != j.end()) { + proxy_info_found = true; + jv1 = *j.find("PROXY_V1"); + } + } + if (tests == 2) { // we must found PROXY_V1 + ok(proxy_info_found == true, "PROXY_V1 %sfound", proxy_info_found ? "" : "not "); + if (proxy_info_found == true) { + ret++; + diag("%s",jv1.dump().c_str()); + } + } else if (tests == 1) { // PROXY_V1 should not be present + ok(proxy_info_found == false, "PROXY_V1 %sfound", proxy_info_found ? "" : "not "); + if (proxy_info_found == true) { + diag("%s",jv1.dump().c_str()); + } else { + ret++; + } + } else { + exit(exit_status()); + } + mysql_close(proxysql_mysql); + return ret; +} + +int main(int argc, char** argv) { + CommandLine cl; + + std::vector> Headers; + Headers.push_back(std::make_pair(2, "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n")); + Headers.push_back(std::make_pair(1, "PROXY TCP4 192.168.0.1 192.168.0.11 56324\r\n")); + Headers.push_back(std::make_pair(0, "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443")); + Headers.push_back(std::make_pair(0, "PROXY")); + Headers.push_back(std::make_pair(2, "PROXY TCP6 fe80::d6ae:52ff:fecf:9876 fe80::d6ae:52aa:fecf:1234 56324 443\r\n")); + Headers.push_back(std::make_pair(1, "PROXY TCP6 fe80::d6ae:52ff:fecf:9876 fe80::d6ae:52aa:fecf:1234 56324\r\n")); + Headers.push_back(std::make_pair(0, "PROXY TCP6 fe80::d6ae:52ff:fecf:9876 fe80::d6ae:52aa:fecf:1234 56324 443")); + + int p = 0; + // we will run the tests twice, with: + // - with mysql-proxy_protocol_networks='' + p += Headers.size(); + for (const auto& pair : Headers) { + p += ( pair.first ? 2 : 0); // PROXY_V1 should not be present + } + // - with mysql-proxy_protocol_networks='*' + p += Headers.size(); + for (const auto& pair : Headers) { + p += ( pair.first ? 2 : 0); // perform either 2 checks, or 0 + } + plan(p); + + MYSQL* proxysql_admin = mysql_init(NULL); + // Initialize connections + if (!proxysql_admin) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(proxysql_admin)); + return -1; + } + + if (!mysql_real_connect(proxysql_admin, cl.host, cl.admin_username, cl.admin_password, NULL, cl.admin_port, NULL, 0)) { + fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(proxysql_admin)); + return -1; + } + + diag("Setting mysql-proxy_protocol_networks=''"); + MYSQL_QUERY(proxysql_admin, "SET mysql-proxy_protocol_networks=''"); + MYSQL_QUERY(proxysql_admin, "LOAD MYSQL VARIABLES TO RUNTIME"); + + for (const auto& pair : Headers) { + const std::string& hdr = pair.second; + diag("Testing connection with header: %s", hdr.c_str()); + int arg1 = pair.first ? 1 : 0; // if pair.first is not 0 , we will pass 1 because PROXY_V1 should not be present + int ret = connect_and_run_query(cl, arg1, hdr.c_str()); + int expected = pair.first ? 2 : 0; + ok(ret == expected , "Expected successes: %d , returned successes: %d", expected, ret); + } + + diag("Setting mysql-proxy_protocol_networks='*'"); + MYSQL_QUERY(proxysql_admin, "SET mysql-proxy_protocol_networks='*'"); + MYSQL_QUERY(proxysql_admin, "LOAD MYSQL VARIABLES TO RUNTIME"); + + for (const auto& pair : Headers) { + const std::string& hdr = pair.second; + diag("Testing connection with header: %s", hdr.c_str()); + int ret = connect_and_run_query(cl, pair.first, hdr.c_str()); + int expected = pair.first ? 2 : 0; + ok(ret == expected , "Expected successes: %d , returned successes: %d", expected, ret); + } + + return exit_status(); +}