Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 57 additions & 13 deletions client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@
#include <limits.h>
#endif

#ifdef HAVE_ASSERT_H
#include <assert.h>
#endif

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <algorithm>
#include <arpa/inet.h>
Expand All @@ -59,7 +58,14 @@ bool client::setup_client(benchmark_config *config, abstract_protocol *protocol,
unsigned long long total_num_of_clients = config->clients*config->threads;

// create main connection
shard_connection* conn = new shard_connection(m_connections.size(), this, m_config, m_event_base, protocol);
unsigned int thread_id = 0; // TODO: set actual thread id if available
unsigned int client_index = m_connections.size();
unsigned int num_clients_per_thread = config->clients;
unsigned int conn_id = thread_id * num_clients_per_thread + client_index;
shard_connection* conn = new shard_connection(
client_index, this, m_config, m_event_base, protocol,
conn_id
);
m_connections.push_back(conn);

m_obj_gen = objgen->clone();
Expand Down Expand Up @@ -99,7 +105,7 @@ bool client::setup_client(benchmark_config *config, abstract_protocol *protocol,
return true;
}

client::client(client_group* group) :
client::client(client_group* group, unsigned int conn_id) :
m_event_base(NULL), m_initialized(false), m_end_set(false), m_config(NULL),
m_obj_gen(NULL), m_stats(group->get_config()), m_reqs_processed(0), m_reqs_generated(0),
m_set_ratio_count(0), m_get_ratio_count(0),
Expand All @@ -108,16 +114,21 @@ client::client(client_group* group) :
{
m_event_base = group->get_event_base();

// Initialize conn_id string and value with prefix
m_conn_id_str = "user" + std::to_string(conn_id);
m_conn_id_value = m_conn_id_str.c_str();
m_conn_id_value_len = m_conn_id_str.length();

if (!setup_client(group->get_config(), group->get_protocol(), group->get_obj_gen())) {
return;
}

benchmark_debug_log("new client %p successfully set up.\n", this);
benchmark_debug_log("new client %p successfully set up with conn_id: %s.\n", this, m_conn_id_value);
m_initialized = true;
}

client::client(struct event_base *event_base, benchmark_config *config,
abstract_protocol *protocol, object_generator *obj_gen) :
abstract_protocol *protocol, object_generator *obj_gen, unsigned int conn_id) :
m_event_base(NULL), m_initialized(false), m_end_set(false), m_config(NULL),
m_obj_gen(NULL), m_stats(config), m_reqs_processed(0), m_reqs_generated(0),
m_set_ratio_count(0), m_get_ratio_count(0),
Expand All @@ -126,11 +137,16 @@ client::client(struct event_base *event_base, benchmark_config *config,
{
m_event_base = event_base;

// Initialize conn_id string and value
m_conn_id_str = std::to_string(conn_id);
m_conn_id_value = m_conn_id_str.c_str();
m_conn_id_value_len = m_conn_id_str.length();

if (!setup_client(config, protocol, obj_gen)) {
return;
}

benchmark_debug_log("new client %p successfully set up.\n", this);
benchmark_debug_log("new client %p successfully set up with conn_id: %s.\n", this, m_conn_id_value);
m_initialized = true;
}

Expand Down Expand Up @@ -273,7 +289,11 @@ bool client::create_arbitrary_request(unsigned int command_index, struct timeval

const arbitrary_command& cmd = get_arbitrary_command(command_index);

benchmark_debug_log("%s: %s:\n", m_connections[conn_id]->get_readable_id(), cmd.command.c_str());
benchmark_debug_log("%s: %s", m_connections[conn_id]->get_readable_id(), cmd.command.c_str());

// Build final command string for debug output
std::string final_command = cmd.command;
bool has_substitutions = false;

for (unsigned int i = 0; i < cmd.command_args.size(); i++) {
const command_arg* arg = &cmd.command_args[i];
Expand All @@ -293,9 +313,32 @@ bool client::create_arbitrary_request(unsigned int command_index, struct timeval
assert(value_len > 0);

cmd_size += m_connections[conn_id]->send_arbitrary_command(arg, value, value_len);
} else if (arg->type == conn_id_type) {
// Replace __conn_id__ placeholder with actual connection ID
std::string substituted_arg = arg->data;
size_t pos = substituted_arg.find(CONN_PLACEHOLDER);
if (pos != std::string::npos) {
substituted_arg.replace(pos, strlen(CONN_PLACEHOLDER), m_conn_id_value);
has_substitutions = true;
}

cmd_size += m_connections[conn_id]->send_arbitrary_command(arg, substituted_arg.c_str(), substituted_arg.length());

// Replace placeholder in final command string for debug output
pos = final_command.find(CONN_PLACEHOLDER);
if (pos != std::string::npos) {
final_command.replace(pos, strlen(CONN_PLACEHOLDER), m_conn_id_value);
}
}
}

// Show final command if substitutions were made
if (has_substitutions) {
benchmark_debug_log(" -> %s\n", final_command.c_str());
} else {
benchmark_debug_log("\n");
}

m_connections[conn_id]->send_arbitrary_command_end(command_index, &timestamp, cmd_size);
return true;
}
Expand Down Expand Up @@ -581,8 +624,8 @@ bool verify_client::finished(void)

///////////////////////////////////////////////////////////////////////////

client_group::client_group(benchmark_config* config, abstract_protocol *protocol, object_generator* obj_gen) :
m_base(NULL), m_config(config), m_protocol(protocol), m_obj_gen(obj_gen)
client_group::client_group(benchmark_config* config, abstract_protocol *protocol, object_generator* obj_gen, unsigned int thread_id) :
m_base(NULL), m_config(config), m_protocol(protocol), m_obj_gen(obj_gen), m_thread_id(thread_id)
{
m_base = event_base_new();
assert(m_base != NULL);
Expand All @@ -608,11 +651,12 @@ int client_group::create_clients(int num)
{
for (int i = 0; i < num; i++) {
client* c;
unsigned int conn_id = m_thread_id * num + i + 1;

if (m_config->cluster_mode)
c = new cluster_client(this);
c = new cluster_client(this, conn_id);
else
c = new client(this);
c = new client(this, conn_id);

assert(c != NULL);

Expand Down
12 changes: 9 additions & 3 deletions client.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class client : public connections_manager {
// test related
benchmark_config* m_config;
object_generator* m_obj_gen;
std::string m_conn_id_str;
const char* m_conn_id_value;
unsigned int m_conn_id_value_len;
run_stats m_stats;

unsigned long long m_reqs_processed; // requests processed (responses received)
Expand All @@ -78,13 +81,14 @@ class client : public connections_manager {
keylist *m_keylist; // used to construct multi commands

public:
client(client_group* group);
client(struct event_base *event_base, benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen);
client(client_group* group, unsigned int conn_id = 0);
client(struct event_base *event_base, benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen, unsigned int conn_id = 0);
virtual ~client();
bool setup_client(benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen);
int prepare(void);
bool initialized(void);
run_stats* get_stats(void) { return &m_stats; }
const char* get_conn_id_value(void) { return m_conn_id_value; }

virtual get_key_response get_key_for_conn(unsigned int command_index, unsigned int conn_id, unsigned long long* key_index);
virtual bool create_arbitrary_request(unsigned int command_index, struct timeval& timestamp, unsigned int conn_id);
Expand Down Expand Up @@ -203,8 +207,10 @@ class client_group {
abstract_protocol* m_protocol;
object_generator* m_obj_gen;
std::vector<client*> m_clients;
protected:
unsigned int m_thread_id;
public:
client_group(benchmark_config *cfg, abstract_protocol *protocol, object_generator* obj_gen);
client_group(benchmark_config *cfg, abstract_protocol *protocol, object_generator* obj_gen, unsigned int thread_id);
~client_group();

int create_clients(int count);
Expand Down
11 changes: 7 additions & 4 deletions cluster_client.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <assert.h>
/*
* Copyright (C) 2011-2017 Redis Labs Ltd.
*
Expand Down Expand Up @@ -108,7 +109,7 @@ static uint32_t calc_hslot_crc16_cluster(const char *str, size_t length)

///////////////////////////////////////////////////////////////////////////////////////////////////////

cluster_client::cluster_client(client_group* group) : client(group)
cluster_client::cluster_client(client_group* group, unsigned int conn_id) : client(group, conn_id)
{
}

Expand Down Expand Up @@ -159,9 +160,11 @@ void cluster_client::disconnect(void)
}

shard_connection* cluster_client::create_shard_connection(abstract_protocol* abs_protocol) {
shard_connection* sc = new shard_connection(m_connections.size(), this,
m_config, m_event_base,
abs_protocol);
unsigned int conn_id = m_connections.size();
shard_connection* sc = new shard_connection(
conn_id, this, m_config, m_event_base, abs_protocol,
conn_id
);
assert(sc != NULL);

m_connections.push_back(sc);
Expand Down
2 changes: 1 addition & 1 deletion cluster_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class cluster_client : public client {
request *request, protocol_response *response);

public:
cluster_client(client_group* group);
cluster_client(client_group* group, unsigned int conn_id);
virtual ~cluster_client();

virtual get_key_response get_key_for_conn(unsigned int command_index, unsigned int conn_id, unsigned long long* key_index);
Expand Down
4 changes: 3 additions & 1 deletion config_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ struct server_addr {

#define KEY_PLACEHOLDER "__key__"
#define DATA_PLACEHOLDER "__data__"
#define CONN_PLACEHOLDER "__conn_id__"

enum command_arg_type {
const_type = 0,
key_type = 1,
data_type = 2,
undefined_type = 3
conn_id_type = 3,
undefined_type = 4
};

struct command_arg {
Expand Down
3 changes: 2 additions & 1 deletion memtier_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,8 @@ struct cg_thread {
m_protocol = protocol_factory(m_config->protocol);
assert(m_protocol != NULL);

m_cg = new client_group(m_config, m_protocol, m_obj_gen);
// Pass thread_id to client_group
m_cg = new client_group(m_config, m_protocol, m_obj_gen, m_thread_id);
}

~cg_thread()
Expand Down
31 changes: 15 additions & 16 deletions protocol.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <assert.h>
/*
* Copyright (C) 2011-2017 Redis Labs Ltd.
*
Expand Down Expand Up @@ -175,7 +176,7 @@ class redis_protocol : public abstract_protocol {
redis_protocol() : m_response_state(rs_initial), m_bulk_len(0), m_response_len(0), m_total_bulks_count(0), m_current_mbulk(NULL), m_resp3(false), m_attribute(false) { }
virtual redis_protocol* clone(void) { return new redis_protocol(); }
virtual int select_db(int db);
virtual int authenticate(const char *credentials);
virtual int authenticate(const char *user, const char *credentials);
virtual int configure_protocol(enum PROTOCOL_TYPE type);
virtual int write_command_cluster_slots();
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset);
Expand Down Expand Up @@ -206,7 +207,7 @@ int redis_protocol::select_db(int db)
return size;
}

int redis_protocol::authenticate(const char *credentials)
int redis_protocol::authenticate(const char *user, const char *credentials)
{
int size = 0;
assert(credentials != NULL);
Expand All @@ -219,7 +220,6 @@ int redis_protocol::authenticate(const char *credentials)
* contains a colon.
*/

const char *user = NULL;
const char *password;

if (credentials[0] == ':') {
Expand All @@ -229,12 +229,11 @@ int redis_protocol::authenticate(const char *credentials)
if (!password) {
password = credentials;
} else {
user = credentials;
password++;
}
}

if (!user) {
if (!user || strlen(user) == 0) {
size = evbuffer_add_printf(m_write_buf,
"*2\r\n"
"$4\r\n"
Expand All @@ -243,17 +242,16 @@ int redis_protocol::authenticate(const char *credentials)
"%s\r\n",
strlen(password), password);
} else {
size_t user_len = password - user - 1;
size_t user_len = strlen(user);
size = evbuffer_add_printf(m_write_buf,
"*3\r\n"
"$4\r\n"
"AUTH\r\n"
"$%zu\r\n"
"%.*s\r\n"
"%s\r\n"
"$%zu\r\n"
"%s\r\n",
user_len,
(int) user_len,
user,
strlen(password),
password);
Expand Down Expand Up @@ -723,8 +721,10 @@ bool redis_protocol::format_arbitrary_command(arbitrary_command &cmd) {
benchmark_error_log("error: data placeholder can't combined with other data\n");
return false;
}

current_arg->type = data_type;
} else if (current_arg->data.find(CONN_PLACEHOLDER) != std::string::npos) {
// Allow conn_id placeholder to be combined with other text
current_arg->type = conn_id_type;
}

// we expect that first arg is the COMMAND name
Expand Down Expand Up @@ -761,7 +761,7 @@ class memcache_text_protocol : public abstract_protocol {
memcache_text_protocol() : m_response_state(rs_initial), m_value_len(0), m_response_len(0) { }
virtual memcache_text_protocol* clone(void) { return new memcache_text_protocol(); }
virtual int select_db(int db);
virtual int authenticate(const char *credentials);
virtual int authenticate(const char *user, const char *credentials);
virtual int configure_protocol(enum PROTOCOL_TYPE type);
virtual int write_command_cluster_slots();
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset);
Expand All @@ -782,7 +782,7 @@ int memcache_text_protocol::select_db(int db)
assert(0);
}

int memcache_text_protocol::authenticate(const char *credentials)
int memcache_text_protocol::authenticate(const char *user, const char *credentials)
{
assert(0);
}
Expand Down Expand Up @@ -983,7 +983,7 @@ class memcache_binary_protocol : public abstract_protocol {
memcache_binary_protocol() : m_response_state(rs_initial), m_response_len(0) { }
virtual memcache_binary_protocol* clone(void) { return new memcache_binary_protocol(); }
virtual int select_db(int db);
virtual int authenticate(const char *credentials);
virtual int authenticate(const char *user, const char *credentials);
virtual int configure_protocol(enum PROTOCOL_TYPE type);
virtual int write_command_cluster_slots();
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset);
Expand All @@ -1003,14 +1003,13 @@ int memcache_binary_protocol::select_db(int db)
assert(0);
}

int memcache_binary_protocol::authenticate(const char *credentials)
int memcache_binary_protocol::authenticate(const char *user, const char *credentials)
{
protocol_binary_request_no_extras req;
char nullbyte = '\0';
const char mechanism[] = "PLAIN";
int mechanism_len = sizeof(mechanism) - 1;
const char *colon;
const char *user;
int user_len;
const char *passwd;
int passwd_len;
Expand All @@ -1019,8 +1018,8 @@ int memcache_binary_protocol::authenticate(const char *credentials)
colon = strchr(credentials, ':');
assert(colon != NULL);

user = credentials;
user_len = colon - user;
// Use the user parameter instead of extracting from credentials
user_len = strlen(user);
passwd = colon + 1;
passwd_len = strlen(passwd);

Expand Down
2 changes: 1 addition & 1 deletion protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class abstract_protocol {
void set_keep_value(bool flag);

virtual int select_db(int db) = 0;
virtual int authenticate(const char *credentials) = 0;
virtual int authenticate(const char *user, const char *credentials) = 0;
virtual int configure_protocol(enum PROTOCOL_TYPE type) = 0;
virtual int write_command_cluster_slots() = 0;
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset) = 0;
Expand Down
Loading
Loading