rdma: add lock for buffers (#858)
Signed-off-by: Ric Li <[email protected]>
ricmli authored May 13, 2024
1 parent 587c4a8 commit 2c6c1f9
Showing 3 changed files with 53 additions and 5 deletions.
2 changes: 2 additions & 0 deletions rdma/mt_rdma.h
@@ -115,6 +115,7 @@ struct mt_rdma_tx_buffer {
uint32_t remote_key;
uint64_t remote_addr;
uint32_t ref_count;
pthread_mutex_t lock;
};

struct mt_rdma_tx_ctx {
@@ -153,6 +154,7 @@ struct mt_rdma_rx_buffer {
enum mt_rdma_buffer_status status;
struct mtl_rdma_buffer buffer;
struct ibv_mr* mr;
pthread_mutex_t lock;
};

struct mt_rdma_rx_ctx {
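The new per-buffer pthread_mutex_t lock field serializes status transitions on each TX/RX buffer between the CQ poll thread and the get/put API calls. Below is a minimal sketch of the lifecycle this commit follows (init at alloc, lock around every status check-and-set, unlock on every exit path, destroy at free); the mt_demo_buffer names and status values are illustrative stand-ins, not types from the repository.

#include <pthread.h>

/* stand-ins for the real mt_rdma_rx/tx_buffer structs and status enum */
enum demo_status { DEMO_FREE, DEMO_READY, DEMO_IN_CONSUMPTION };

struct mt_demo_buffer {
  enum demo_status status;
  pthread_mutex_t lock; /* guards status, like the new per-buffer lock */
};

static void demo_buffer_init(struct mt_demo_buffer* b) {
  b->status = DEMO_FREE;
  pthread_mutex_init(&b->lock, NULL); /* once per buffer at alloc time */
}

/* claim the buffer only if it is READY; every exit path releases the lock */
static int demo_buffer_claim(struct mt_demo_buffer* b) {
  pthread_mutex_lock(&b->lock);
  if (b->status != DEMO_READY) {
    pthread_mutex_unlock(&b->lock);
    return -1; /* not ready, or another thread owns it */
  }
  b->status = DEMO_IN_CONSUMPTION;
  pthread_mutex_unlock(&b->lock);
  return 0;
}

static void demo_buffer_uinit(struct mt_demo_buffer* b) {
  pthread_mutex_destroy(&b->lock); /* once per buffer at free time */
}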
24 changes: 23 additions & 1 deletion rdma/mt_rdma_rx.c
@@ -22,20 +22,23 @@ static struct mt_rdma_message* rdma_rx_get_recv_msg(struct mt_rdma_rx_ctx* ctx)

static int rdma_rx_send_buffer_done(struct mt_rdma_rx_ctx* ctx, uint16_t idx) {
struct mt_rdma_rx_buffer* rx_buffer = &ctx->rx_buffers[idx];
pthread_mutex_lock(&rx_buffer->lock);
struct mt_rdma_message msg = {
.magic = MT_RDMA_MSG_MAGIC,
.type = MT_RDMA_MSG_BUFFER_DONE,
.buf_done.buf_idx = idx,
.buf_done.seq_num = 0, /* todo */
.buf_done.seq_num = rx_buffer->buffer.seq_num,
.buf_done.rx_buf_addr = (uint64_t)rx_buffer->buffer.addr,
.buf_done.rx_buf_key = rx_buffer->mr->rkey,
};
int ret = rdma_post_send(ctx->id, NULL, &msg, sizeof(msg), NULL, IBV_SEND_INLINE);
if (ret) {
err("%s(%s), rdma_post_send failed: %s\n", __func__, ctx->ops_name, strerror(errno));
pthread_mutex_unlock(&rx_buffer->lock);
return -EIO;
}
rx_buffer->status = MT_RDMA_BUFFER_STATUS_FREE;
pthread_mutex_unlock(&rx_buffer->lock);
return 0;
}

@@ -78,6 +81,9 @@ static int rdma_rx_init_mrs(struct mt_rdma_rx_ctx* ctx) {
static int rdma_rx_free_buffers(struct mt_rdma_rx_ctx* ctx) {
rdma_rx_uinit_mrs(ctx);
MT_SAFE_FREE(ctx->recv_msgs, free);
for (int i = 0; i < ctx->buffer_cnt; i++) {
pthread_mutex_destroy(&ctx->rx_buffers[i].lock);
}
MT_SAFE_FREE(ctx->rx_buffers, free);
return 0;
}
@@ -98,6 +104,7 @@ static int rdma_rx_alloc_buffers(struct mt_rdma_rx_ctx* ctx) {
rx_buffer->status = MT_RDMA_BUFFER_STATUS_FREE;
rx_buffer->buffer.addr = ops->buffers[i];
rx_buffer->buffer.capacity = ops->buffer_capacity;
pthread_mutex_init(&rx_buffer->lock, NULL);
}

/* alloc receive message region including metadata, send messages are inlined */
@@ -173,6 +180,7 @@ static void* rdma_rx_cq_poll_thread(void* arg) {
idx = msg->buf_meta.buf_idx;
dbg("%s(%s), buffer %u meta received\n", __func__, ctx->ops_name, idx);
rx_buffer = &ctx->rx_buffers[idx];
pthread_mutex_lock(&rx_buffer->lock);
rx_buffer->buffer.user_meta =
(void*)(msg + 1); /* this msg buffer in use by meta */
rx_buffer->buffer.user_meta_size = msg->buf_meta.meta_size;
@@ -188,17 +196,21 @@
}
}
} else if (rx_buffer->status == MT_RDMA_BUFFER_STATUS_FREE) {
rx_buffer->buffer.seq_num = msg->buf_meta.seq_num;
rx_buffer->status = MT_RDMA_BUFFER_STATUS_IN_TRANSMISSION;
} else {
err("%s(%s), buffer %u unexpected status %d\n", __func__, ctx->ops_name,
idx, rx_buffer->status);
pthread_mutex_unlock(&rx_buffer->lock);
goto out;
}
pthread_mutex_unlock(&rx_buffer->lock);
break;
case MT_RDMA_MSG_BUFFER_READY:
idx = msg->buf_ready.buf_idx;
dbg("%s(%s), buffer %u ready received\n", __func__, ctx->ops_name, idx);
rx_buffer = &ctx->rx_buffers[idx];
pthread_mutex_lock(&rx_buffer->lock);
if (rx_buffer->status == MT_RDMA_BUFFER_STATUS_IN_TRANSMISSION) {
rx_buffer->status = MT_RDMA_BUFFER_STATUS_READY;
ctx->stat_buffer_received++;
@@ -211,12 +223,15 @@
}
}
} else if (rx_buffer->status == MT_RDMA_BUFFER_STATUS_FREE) {
rx_buffer->buffer.seq_num = msg->buf_ready.seq_num;
rx_buffer->status = MT_RDMA_BUFFER_STATUS_WAIT_META;
} else {
err("%s(%s), buffer %u unexpected status %d\n", __func__, ctx->ops_name,
idx, rx_buffer->status);
pthread_mutex_unlock(&rx_buffer->lock);
goto out;
}
pthread_mutex_unlock(&rx_buffer->lock);
msg->type = MT_RDMA_MSG_NONE; /* recycle receive msg */
break;
case MT_RDMA_MSG_BYE:
@@ -400,10 +415,13 @@ struct mtl_rdma_buffer* mtl_rdma_rx_get_buffer(mtl_rdma_rx_handle handle) {
/* find a ready buffer */
for (int i = 0; i < ctx->buffer_cnt; i++) {
struct mt_rdma_rx_buffer* rx_buffer = &ctx->rx_buffers[i];
pthread_mutex_lock(&rx_buffer->lock);
if (rx_buffer->status == MT_RDMA_BUFFER_STATUS_READY) {
rx_buffer->status = MT_RDMA_BUFFER_STATUS_IN_CONSUMPTION;
pthread_mutex_unlock(&rx_buffer->lock);
return &rx_buffer->buffer;
}
pthread_mutex_unlock(&rx_buffer->lock);
}

return NULL;
@@ -417,17 +435,21 @@ int mtl_rdma_rx_put_buffer(mtl_rdma_rx_handle handle, struct mtl_rdma_buffer* bu

for (int i = 0; i < ctx->buffer_cnt; i++) {
struct mt_rdma_rx_buffer* rx_buffer = &ctx->rx_buffers[i];
pthread_mutex_lock(&rx_buffer->lock);
if (&rx_buffer->buffer == buffer) {
if (rx_buffer->status != MT_RDMA_BUFFER_STATUS_IN_CONSUMPTION) {
err("%s(%s), buffer %p not in consumption\n", __func__, ctx->ops_name, buffer);
pthread_mutex_unlock(&rx_buffer->lock);
return -EIO;
}
/* recycle meta in use receive msg */
struct mt_rdma_message* meta_msg =
(struct mt_rdma_message*)rx_buffer->buffer.user_meta - 1;
meta_msg->type = MT_RDMA_MSG_NONE;
pthread_mutex_unlock(&rx_buffer->lock);
return rdma_rx_send_buffer_done(ctx, rx_buffer->idx);
}
pthread_mutex_unlock(&rx_buffer->lock);
}

err("%s(%s), buffer %p not found\n", __func__, ctx->ops_name, buffer);
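With the lock in place, mtl_rdma_rx_get_buffer() can safely flip a READY buffer to IN_CONSUMPTION while the CQ poll thread updates other buffers, and mtl_rdma_rx_put_buffer() now reports the buffer's real seq_num back to the sender. The sketch below shows the consumer side under stated assumptions: the RX handle is created elsewhere, the header name mtl_rdma_api.h and the exact integer types of size/user_meta_size/seq_num are guesses; only the mtl_rdma_rx_get_buffer/mtl_rdma_rx_put_buffer calls and the buffer fields come from this diff.

#include <stdio.h>

#include "mtl_rdma_api.h" /* assumed header name exposing the mtl_rdma_rx_* API */

/* Consume one buffer from an already-created RX handle; putting it back
 * drives rdma_rx_send_buffer_done(), which now carries the real seq_num. */
static int demo_consume_one(mtl_rdma_rx_handle rx) {
  struct mtl_rdma_buffer* b = mtl_rdma_rx_get_buffer(rx);
  if (!b) return -1; /* no buffer currently in READY state */

  /* payload and user metadata are valid while we hold the buffer */
  printf("seq %u: %zu bytes, %zu bytes of meta\n", (unsigned)b->seq_num,
         (size_t)b->size, (size_t)b->user_meta_size);

  return mtl_rdma_rx_put_buffer(rx, b); /* marks it FREE and notifies TX */
}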
32 changes: 28 additions & 4 deletions rdma/mt_rdma_tx.c
@@ -58,6 +58,9 @@ static int rdma_tx_free_buffers(struct mt_rdma_tx_ctx* ctx) {
rdma_tx_uinit_mrs(ctx);
MT_SAFE_FREE(ctx->send_msgs, free);
MT_SAFE_FREE(ctx->recv_msgs, free);
for (int i = 0; i < ctx->buffer_cnt; i++) {
pthread_mutex_destroy(&ctx->tx_buffers[i].lock);
}
MT_SAFE_FREE(ctx->tx_buffers, free);
return 0;
}
@@ -80,6 +83,7 @@ static int rdma_tx_alloc_buffers(struct mt_rdma_tx_ctx* ctx) {
tx_buffer->ref_count = 1;
tx_buffer->buffer.addr = ops->buffers[i];
tx_buffer->buffer.capacity = ops->buffer_capacity;
pthread_mutex_init(&tx_buffer->lock, NULL);
}

/* alloc receive message region */
@@ -158,12 +162,14 @@ static void* rdma_tx_cq_poll_thread(void* arg) {
if (msg->magic == MT_RDMA_MSG_MAGIC) {
if (msg->type == MT_RDMA_MSG_BUFFER_DONE) {
uint16_t idx = msg->buf_done.buf_idx;
dbg("%s(%s), received buffer %u done message\n", __func__, ctx->ops_name,
idx);
dbg("%s(%s), received buffer %u done message, seq %u\n", __func__,
ctx->ops_name, idx, msg->buf_done.seq_num);
struct mt_rdma_tx_buffer* tx_buffer = &ctx->tx_buffers[idx];
pthread_mutex_lock(&tx_buffer->lock);
if (tx_buffer->status != MT_RDMA_BUFFER_STATUS_IN_CONSUMPTION) {
err("%s(%s), received buffer done message with invalid status %d\n",
__func__, ctx->ops_name, tx_buffer->status);
pthread_mutex_unlock(&tx_buffer->lock);
goto out;
}
tx_buffer->remote_addr = msg->buf_done.rx_buf_addr;
@@ -179,6 +185,7 @@
}
}
}
pthread_mutex_unlock(&tx_buffer->lock);
ctx->stat_buffer_acked++;
}
} else if (msg->type == MT_RDMA_MSG_BYE) {
@@ -196,17 +203,19 @@
}
} else if (wc.opcode == IBV_WC_RDMA_WRITE) {
struct mt_rdma_tx_buffer* tx_buffer = (struct mt_rdma_tx_buffer*)wc.wr_id;
/* send ready message to rx, todo add user meta with sgl */
pthread_mutex_lock(&tx_buffer->lock);
/* send ready message to rx */
struct mt_rdma_message msg = {
.magic = MT_RDMA_MSG_MAGIC,
.type = MT_RDMA_MSG_BUFFER_READY,
.buf_ready.buf_idx = tx_buffer->idx,
.buf_ready.seq_num = 0, /* todo */
.buf_ready.seq_num = ctx->buffer_seq_num++,
};
ret = rdma_post_send(ctx->id, NULL, &msg, sizeof(msg), NULL, IBV_SEND_INLINE);
if (ret) {
err("%s(%s), rdma_post_send failed: %s\n", __func__, ctx->ops_name,
strerror(errno));
pthread_mutex_unlock(&tx_buffer->lock);
goto out;
}
dbg("%s(%s), send buffer %d ready message\n", __func__, ctx->ops_name,
@@ -220,6 +229,7 @@
/* todo: error handle */
}
}
pthread_mutex_unlock(&tx_buffer->lock);
ctx->stat_buffer_sent++;
} else if (wc.opcode == IBV_WC_SEND) {
if (wc.wr_id == MT_RDMA_MSG_BYE) {
@@ -376,10 +386,13 @@ struct mtl_rdma_buffer* mtl_rdma_tx_get_buffer(mtl_rdma_tx_handle handle) {
/* change to use buffer_producer_idx to act as a queue */
for (int i = 0; i < ctx->buffer_cnt; i++) {
struct mt_rdma_tx_buffer* tx_buffer = &ctx->tx_buffers[i];
pthread_mutex_lock(&tx_buffer->lock);
if (tx_buffer->status == MT_RDMA_BUFFER_STATUS_FREE) {
tx_buffer->status = MT_RDMA_BUFFER_STATUS_IN_PRODUCTION;
pthread_mutex_unlock(&tx_buffer->lock);
return &tx_buffer->buffer;
}
pthread_mutex_unlock(&tx_buffer->lock);
}
return NULL;
}
@@ -390,16 +403,23 @@ int mtl_rdma_tx_put_buffer(mtl_rdma_tx_handle handle, struct mtl_rdma_buffer* bu
return -EIO;
}

if (buffer->size > buffer->capacity) {
err("%s(%s), buffer size is too large\n", __func__, ctx->ops_name);
return -EIO;
}

if (buffer->user_meta_size > MT_RDMA_USER_META_MAX_SIZE) {
err("%s(%s), user meta size is too large\n", __func__, ctx->ops_name);
return -EIO;
}

for (int i = 0; i < ctx->buffer_cnt; i++) {
struct mt_rdma_tx_buffer* tx_buffer = &ctx->tx_buffers[i];
pthread_mutex_lock(&tx_buffer->lock);
if (&tx_buffer->buffer == buffer) {
if (tx_buffer->status != MT_RDMA_BUFFER_STATUS_IN_PRODUCTION) {
err("%s(%s), buffer %p is not in production\n", __func__, ctx->ops_name, buffer);
pthread_mutex_unlock(&tx_buffer->lock);
return -EIO;
}
/* write to rx immediately */
@@ -409,13 +429,15 @@ int mtl_rdma_tx_put_buffer(mtl_rdma_tx_handle handle, struct mtl_rdma_buffer* bu
if (ret) {
err("%s(%s), rdma_post_write failed: %s\n", __func__, ctx->ops_name,
strerror(errno));
pthread_mutex_unlock(&tx_buffer->lock);
return -EIO;
}
/* send user metadata to rx */
struct mt_rdma_message* msg = ctx->send_msgs + i * MT_RDMA_MSG_MAX_SIZE;
msg->magic = MT_RDMA_MSG_MAGIC;
msg->type = MT_RDMA_MSG_BUFFER_META;
msg->buf_meta.buf_idx = i;
msg->buf_meta.seq_num = ctx->buffer_seq_num;
msg->buf_meta.meta_size = buffer->user_meta_size;
memcpy(&msg[1], buffer->user_meta, buffer->user_meta_size);
ret = rdma_post_send(ctx->id, msg, msg, MT_RDMA_MSG_MAX_SIZE, ctx->send_msgs_mr,
@@ -428,8 +450,10 @@ int mtl_rdma_tx_put_buffer(mtl_rdma_tx_handle handle, struct mtl_rdma_buffer* bu
dbg("%s(%s), send meta for buffer %d\n", __func__, ctx->ops_name, i);

tx_buffer->status = MT_RDMA_BUFFER_STATUS_IN_TRANSMISSION;
pthread_mutex_unlock(&tx_buffer->lock);
return 0;
}
pthread_mutex_unlock(&tx_buffer->lock);
}

err("%s(%s), buffer %p not found\n", __func__, ctx->ops_name, buffer);
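On the TX side the same lock lets mtl_rdma_tx_put_buffer() validate and transition the buffer while the CQ poll thread posts the READY/META messages. A producer-side sketch under the same assumptions (pre-created handle, guessed header name and field types); the size, capacity, addr, user_meta and user_meta_size fields and the two mtl_rdma_tx_* calls are taken from this diff.

#include <string.h>

#include "mtl_rdma_api.h" /* assumed header name exposing the mtl_rdma_tx_* API */

/* Produce one buffer on an already-created TX handle: claim a FREE buffer,
 * fill payload plus user metadata, then put it so the library performs the
 * RDMA write and sends the BUFFER_READY / BUFFER_META messages shown above. */
static int demo_produce_one(mtl_rdma_tx_handle tx, const void* data, size_t len,
                            void* meta, size_t meta_len) {
  struct mtl_rdma_buffer* b = mtl_rdma_tx_get_buffer(tx);
  if (!b) return -1; /* all buffers busy (none in FREE state) */

  /* mtl_rdma_tx_put_buffer() rejects len > capacity and metadata larger than
   * MT_RDMA_USER_META_MAX_SIZE, so size these against b->capacity up front */
  memcpy(b->addr, data, len);
  b->size = len;
  b->user_meta = meta;
  b->user_meta_size = meta_len;

  return mtl_rdma_tx_put_buffer(tx, b);
}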
