Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change stream list to arrays and optimize stream lookup with SSE2 #508

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 197 additions & 34 deletions srtp/srtp.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,22 @@
#include "aes_icm_ext.h"
#endif

#include <stddef.h>
#include <string.h>
#include <limits.h>
#ifdef HAVE_NETINET_IN_H
#include <netinet/in.h>
#elif defined(HAVE_WINSOCK2_H)
#include <winsock2.h>
#endif

#if defined(__SSE2__)
#include <emmintrin.h>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
#endif

/* the debug module for srtp */
srtp_debug_module_t mod_srtp = {
0, /* debugging is off by default */
Expand All @@ -79,6 +88,16 @@ srtp_debug_module_t mod_srtp = {
#define uint32s_in_rtcp_header 2
#define octets_in_rtp_extn_hdr 4

/*
 * Forward declarations of the file-local helpers of the default stream
 * list implementation. They are declared here because code earlier in the
 * file (e.g. srtp_remove_stream()) calls them before their definitions,
 * which appear further down under the same SRTP_NO_STREAM_LIST guard.
 */
#ifndef SRTP_NO_STREAM_LIST
static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list);
static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list,
uint32_t new_capacity);
static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc);
static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list,
uint32_t pos);
static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos);
#endif // SRTP_NO_STREAM_LIST

static srtp_err_status_t srtp_validate_rtp_header(void *rtp_hdr,
int *pkt_octet_len)
{
Expand Down Expand Up @@ -3030,18 +3049,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc)
{
srtp_stream_ctx_t *stream;
srtp_err_status_t status;
#if !defined(SRTP_NO_STREAM_LIST)
uint32_t pos;
#endif

/* sanity check arguments */
if (session == NULL)
if (session == NULL) {
return srtp_err_status_bad_param;
}

/* find and remove stream from the list */
#if !defined(SRTP_NO_STREAM_LIST)
pos = srtp_stream_list_find(session->stream_list, ssrc);
if (pos >= srtp_stream_list_size(session->stream_list))
return srtp_err_status_no_ctx;

stream = srtp_stream_list_get_at(session->stream_list, pos);
srtp_stream_list_remove_at(session->stream_list, pos);
#else
stream = srtp_stream_list_get(session->stream_list, ssrc);
if (stream == NULL) {
return srtp_err_status_no_ctx;
}

srtp_stream_list_remove(session->stream_list, stream);
#endif

/* deallocate the stream */
status = srtp_stream_dealloc(stream, session->stream_template);
Lastique marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -4840,11 +4872,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session,

#ifndef SRTP_NO_STREAM_LIST

/* in the default implementation, we have an intrusive doubly-linked list */
/*
 * Stream list context.
 *
 * NOTE(review): this span is rendered from a diff; the `data` member and
 * its comment look like the removed linked-list layout interleaved with
 * the new array members - confirm against the real file.
 */
typedef struct srtp_stream_list_ctx_t_ {
/* a stub stream that just holds pointers to the beginning and end of the
* list */
srtp_stream_ctx_t data;
/* ssrc of each stored stream; entry i belongs to streams[i] */
uint32_t *ssrcs;
/* pointers to the stored streams, in insertion order */
srtp_stream_ctx_t **streams;
/* number of entries currently in use in both arrays */
uint32_t size;
/* number of entries allocated for each of the two arrays */
uint32_t capacity;
} srtp_stream_list_ctx_t_;

srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
Expand All @@ -4855,73 +4887,204 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
return srtp_err_status_alloc_fail;
}

list->data.next = NULL;
list->data.prev = NULL;

*list_ptr = list;
return srtp_err_status_ok;
}

/*
 * srtp_stream_list_dealloc() frees the ssrc array, the stream pointer
 * array and the list object itself. The list must already be empty;
 * otherwise nothing is freed and srtp_err_status_fail is returned.
 *
 * NOTE(review): this span is rendered from a diff; the `list->data.next`
 * test looks like the removed linked-list check interleaved with the new
 * `list->size` check - confirm against the real file.
 */
srtp_err_status_t srtp_stream_list_dealloc(srtp_stream_list_t list)
{
/* list must be empty */
if (list->data.next) {
if (list->size != 0u) {
return srtp_err_status_fail;
}
/* the streams themselves are not freed here, only the bookkeeping arrays */
srtp_crypto_free(list->streams);
srtp_crypto_free(list->ssrcs);
srtp_crypto_free(list);
return srtp_err_status_ok;
}

/*
 * srtp_stream_list_size() reports how many streams are currently stored
 * in the list.
 */
static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list)
{
    const uint32_t count = list->size;

    return count;
}

/*
 * srtp_stream_list_reserve() grows the backing arrays of the list so that
 * they can hold at least new_capacity entries. Existing entries are copied
 * into the new arrays and the old arrays are released.
 *
 * list          - list whose storage is grown
 * new_capacity  - minimum number of entries that must fit
 *
 * Returns srtp_err_status_ok when the capacity is already sufficient or
 * the arrays were reallocated, and srtp_err_status_alloc_fail when the
 * request cannot be represented or an allocation fails. On failure the
 * list is left untouched.
 */
static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list,
                                                  uint32_t new_capacity)
{
    uint32_t *ssrcs;
    srtp_stream_ctx_t **stream_ptrs;

    if (new_capacity <= list->capacity) {
        return srtp_err_status_ok;
    }

    /* rounding up below must not wrap around */
    if (new_capacity > (UINT32_MAX - 15u)) {
        return srtp_err_status_alloc_fail;
    }

    /*
     * Keep the capacity a multiple of 16: the SSE2 path of
     * srtp_stream_list_find() loads whole groups of 8 or 16 ssrcs, so the
     * padding keeps those loads inside the allocation.
     */
    new_capacity = (new_capacity + 15u) & ~((uint32_t)15u);

    /* the byte counts passed to the allocator must not overflow size_t */
    if ((size_t)new_capacity > SIZE_MAX / sizeof(*ssrcs) ||
        (size_t)new_capacity > SIZE_MAX / sizeof(*stream_ptrs)) {
        return srtp_err_status_alloc_fail;
    }

    ssrcs = (uint32_t *)srtp_crypto_alloc((size_t)new_capacity *
                                          sizeof(*ssrcs));
    if (ssrcs == NULL) {
        return srtp_err_status_alloc_fail;
    }

    stream_ptrs = (srtp_stream_ctx_t **)srtp_crypto_alloc(
        (size_t)new_capacity * sizeof(*stream_ptrs));
    if (stream_ptrs == NULL) {
        srtp_crypto_free(ssrcs);
        return srtp_err_status_alloc_fail;
    }

    /* an empty list may not have arrays yet, so only copy real entries */
    if (list->size > 0u) {
        memcpy(ssrcs, list->ssrcs, (size_t)list->size * sizeof(*ssrcs));
        memcpy(stream_ptrs, list->streams,
               (size_t)list->size * sizeof(*stream_ptrs));
    }

    srtp_crypto_free(list->ssrcs);
    srtp_crypto_free(list->streams);
    list->streams = stream_ptrs;
    list->ssrcs = ssrcs;
    list->capacity = new_capacity;

    return srtp_err_status_ok;
}

/*
 * srtp_stream_list_insert() stores stream in the list. The array-based
 * code below reserves room for one more entry, then writes the stream's
 * ssrc and pointer into the first unused slot of the two arrays. If the
 * reservation fails its status is returned and the list is not modified.
 *
 * NOTE(review): this span is rendered from a diff; the statements that use
 * `list->data` and `stream->next/prev` look like the removed linked-list
 * insertion interleaved with the new code - confirm against the real file.
 */
srtp_err_status_t srtp_stream_list_insert(srtp_stream_list_t list,
srtp_stream_t stream)
{
/* insert at the head of the list */
stream->next = list->data.next;
if (stream->next != NULL) {
stream->next->prev = stream;
}
list->data.next = stream;
stream->prev = &(list->data);
uint32_t pos;
/* make sure there is room for one more entry */
srtp_err_status_t status = srtp_stream_list_reserve(list, list->size + 1u);
if (status)
return status;
/* append: the new entry takes the slot at the old size */
pos = list->size++;
list->ssrcs[pos] = stream->ssrc;
list->streams[pos] = stream;

return srtp_err_status_ok;
}

srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc)
/*
 * srtp_stream_list_find() returns the index of the first entry of
 * list->ssrcs that equals ssrc. On a miss the returned index is greater
 * than or equal to list->size (in the SSE2 path it can be larger than the
 * size), so callers must compare the result against the list size before
 * using it as an index.
 *
 * NOTE(review): this span is rendered from a diff; the statements that use
 * `stream` (the walk starting at list->data.next) look like the removed
 * linked-list lookup interleaved with the new code - confirm against the
 * real file.
 */
static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc)
{
/* walk down list until ssrc is found */
srtp_stream_t stream = list->data.next;
while (stream != NULL) {
if (stream->ssrc == ssrc) {
return stream;
#if defined(__SSE2__)
const uint32_t *const ssrcs = list->ssrcs;
/* broadcast the wanted ssrc into all four 32-bit lanes */
const __m128i mm_ssrc = _mm_set1_epi32(ssrc);
/*
 * n is the size rounded up to a multiple of 8, so the loads below may
 * read slots past the last used entry. This relies on the arrays being
 * allocated with padding (srtp_stream_list_reserve() rounds the capacity
 * up to a multiple of 16).
 */
uint32_t pos = 0u, n = (list->size + 7u) & ~(uint32_t)(7u);
/* main loop: 16 ssrcs per iteration; m is n rounded down to 16 */
for (uint32_t m = n & ~(uint32_t)(15u); pos < m; pos += 16u) {
__m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos));
__m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u));
__m128i mm3 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 8u));
__m128i mm4 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 12u));
/* each matching 32-bit lane becomes all ones, others become zero */
mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc);
mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc);
mm3 = _mm_cmpeq_epi32(mm3, mm_ssrc);
mm4 = _mm_cmpeq_epi32(mm4, mm_ssrc);
/*
 * narrow the 16 results from 32 to 16 to 8 bits, keeping their order,
 * so that byte i of mm1 reflects ssrcs[pos + i]
 */
mm1 = _mm_packs_epi32(mm1, mm2);
mm3 = _mm_packs_epi32(mm3, mm4);
mm1 = _mm_packs_epi16(mm1, mm3);
/* one mask bit per ssrc: bit i is set when ssrcs[pos + i] matched */
uint32_t mask = _mm_movemask_epi8(mm1);
if (mask) {
/* the lowest set bit is the first match within this group */
#if defined(_MSC_VER)
unsigned long bit_pos;
_BitScanForward(&bit_pos, mask);
pos += bit_pos;
#else
pos += __builtin_ctz(mask);
#endif

goto done;
}
}

/* tail: at most one remaining group of 8 ssrcs */
if (pos < n) {
__m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos));
__m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u));
mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc);
mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc);
/* results are only narrowed to 16 bits here, i.e. two bytes per ssrc */
mm1 = _mm_packs_epi32(mm1, mm2);

uint32_t mask = _mm_movemask_epi8(mm1);
if (mask) {
/* two mask bits per ssrc, hence the division by two */
#if defined(_MSC_VER)
unsigned long bit_pos;
_BitScanForward(&bit_pos, mask);
pos += bit_pos / 2u;
#else
pos += __builtin_ctz(mask) / 2u;
#endif
goto done;
}
stream = stream->next;

/* no match: step past the group so that pos == n, which is >= size */
pos += 8u;
}

done:
return pos;
#else
/* walk down list until ssrc is found */
uint32_t pos = 0u, n = list->size;
for (; pos < n; ++pos) {
if (list->ssrcs[pos] == ssrc)
break;
}

/* pos == list->size here when nothing matched */
return pos;
#endif
}

/*
 * srtp_stream_list_get_at() returns the stream stored at index pos.
 * No bounds check is performed; the caller must pass a valid index.
 */
static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list,
                                                    uint32_t pos)
{
    srtp_stream_t entry = list->streams[pos];

    return entry;
}

/*
 * srtp_stream_list_get() looks up the stream with the given ssrc.
 *
 * Returns the matching stream, or NULL when no stream in the list has
 * that ssrc.
 */
srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc)
{
    const uint32_t idx = srtp_stream_list_find(list, ssrc);

    /* an index at or beyond the size means the ssrc was not found */
    if (idx >= list->size) {
        return NULL;
    }

    return list->streams[idx];
}

void srtp_stream_list_remove(srtp_stream_list_t list,
srtp_stream_t stream_to_remove)
/*
 * srtp_stream_list_remove_at() removes the entry at index pos. The
 * array-based code below shrinks the size by one, shifts every following
 * entry down one slot in both arrays, and clears the vacated last slot.
 * No bounds check is performed on pos; the callers in this file check the
 * index against the list size first.
 *
 * NOTE(review): this span is rendered from a diff; the statements that use
 * `stream_to_remove` and `(void)list` look like the removed linked-list
 * removal interleaved with the new code - confirm against the real file.
 */
static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos)
{
(void)list;
uint32_t tail_size, last_pos;

stream_to_remove->prev->next = stream_to_remove->next;
if (stream_to_remove->next != NULL) {
stream_to_remove->next->prev = stream_to_remove->prev;
/* last_pos is the index of the slot that becomes unused */
last_pos = --list->size;
/* number of entries that follow the removed one */
tail_size = last_pos - pos;
if (tail_size > 0u) {
/* source and destination overlap, hence memmove */
memmove(list->streams + pos, list->streams + pos + 1,
(size_t)tail_size * sizeof(*list->streams));
memmove(list->ssrcs + pos, list->ssrcs + pos + 1,
(size_t)tail_size * sizeof(*list->ssrcs));
}

/* clear the vacated slot so unused entries hold NULL / 0 */
list->streams[last_pos] = NULL;
list->ssrcs[last_pos] = 0u;
}

/*
 * srtp_stream_list_remove() removes stream_to_remove from the list. If
 * that exact stream object is not stored in the list, the list is left
 * unchanged. The stream itself is not deallocated.
 *
 * list              - list to remove the stream from
 * stream_to_remove  - stream object to remove; must not be NULL
 */
void srtp_stream_list_remove(srtp_stream_list_t list,
                             srtp_stream_t stream_to_remove)
{
    const uint32_t size = list->size;
    uint32_t pos = srtp_stream_list_find(list, stream_to_remove->ssrc);

    /*
     * The lookup is by ssrc and yields the first entry with that ssrc. If
     * another stream shares it, advance until the stored pointer is the
     * very object we were asked to remove, so a different stream is never
     * removed by mistake.
     */
    while (pos < size && list->streams[pos] != stream_to_remove) {
        ++pos;
    }

    if (pos < size) {
        srtp_stream_list_remove_at(list, pos);
    }
}

/*
 * srtp_stream_list_for_each() calls callback(stream, data) for the streams
 * in the list and stops early as soon as the callback returns non-zero.
 * The array-based loop below tolerates a callback that removes the stream
 * it was called with: when the list size changes during the call, the
 * index is not advanced, because the following entry has moved into the
 * current slot.
 *
 * NOTE(review): this span is rendered from a diff; the statements that use
 * `stream`/`tmp` look like the removed linked-list walk interleaved with
 * the new code - confirm against the real file.
 */
void srtp_stream_list_for_each(srtp_stream_list_t list,
int (*callback)(srtp_stream_t, void *),
void *data)
{
srtp_stream_t stream = list->data.next;
while (stream != NULL) {
srtp_stream_t tmp = stream;
stream = stream->next;
if (callback(tmp, data))
/* snapshot of the size, refreshed whenever the callback changes it */
uint32_t size = list->size;
for (uint32_t i = 0u; i < size;) {
if (callback(list->streams[i], data))
break;

/* check if the callback removed the current element */
if (size == list->size)
++i;
else
size = list->size;
}
}

Expand Down