diff --git a/srtp/srtp.c b/srtp/srtp.c index 75b71fb6e..14e41d5cc 100644 --- a/srtp/srtp.c +++ b/srtp/srtp.c @@ -60,6 +60,8 @@ #include "aes_icm_ext.h" #endif +#include +#include #include #ifdef HAVE_NETINET_IN_H #include @@ -67,6 +69,13 @@ #include #endif +#if defined(__SSE2__) +#include +#if defined(_MSC_VER) +#include +#endif +#endif + /* the debug module for srtp */ srtp_debug_module_t mod_srtp = { 0, /* debugging is off by default */ @@ -79,6 +88,16 @@ srtp_debug_module_t mod_srtp = { #define uint32s_in_rtcp_header 2 #define octets_in_rtp_extn_hdr 4 +#ifndef SRTP_NO_STREAM_LIST +static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list); +static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list, + uint32_t new_capacity); +static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc); +static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list, + uint32_t pos); +static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos); +#endif // SRTP_NO_STREAM_LIST + static srtp_err_status_t srtp_validate_rtp_header(void *rtp_hdr, int *pkt_octet_len) { @@ -3030,18 +3049,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc) { srtp_stream_ctx_t *stream; srtp_err_status_t status; +#if !defined(SRTP_NO_STREAM_LIST) + uint32_t pos; +#endif /* sanity check arguments */ - if (session == NULL) + if (session == NULL) { return srtp_err_status_bad_param; + } /* find and remove stream from the list */ +#if !defined(SRTP_NO_STREAM_LIST) + pos = srtp_stream_list_find(session->stream_list, ssrc); + if (pos >= srtp_stream_list_size(session->stream_list)) + return srtp_err_status_no_ctx; + + stream = srtp_stream_list_get_at(session->stream_list, pos); + srtp_stream_list_remove_at(session->stream_list, pos); +#else stream = srtp_stream_list_get(session->stream_list, ssrc); if (stream == NULL) { return srtp_err_status_no_ctx; } srtp_stream_list_remove(session->stream_list, stream); +#endif /* deallocate the stream */ status = srtp_stream_dealloc(stream, session->stream_template); @@ -4840,11 +4872,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session, #ifndef SRTP_NO_STREAM_LIST -/* in the default implementation, we have an intrusive doubly-linked list */ typedef struct srtp_stream_list_ctx_t_ { - /* a stub stream that just holds pointers to the beginning and end of the - * list */ - srtp_stream_ctx_t data; + uint32_t *ssrcs; + srtp_stream_ctx_t **streams; + uint32_t size; + uint32_t capacity; } srtp_stream_list_ctx_t_; srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr) @@ -4855,9 +4887,6 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr) return srtp_err_status_alloc_fail; } - list->data.next = NULL; - list->data.prev = NULL; - *list_ptr = list; return srtp_err_status_ok; } @@ -4865,63 +4894,197 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr) srtp_err_status_t srtp_stream_list_dealloc(srtp_stream_list_t list) { /* list must be empty */ - if (list->data.next) { + if (list->size != 0u) { return srtp_err_status_fail; } + srtp_crypto_free(list->streams); + srtp_crypto_free(list->ssrcs); srtp_crypto_free(list); return srtp_err_status_ok; } +static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list) +{ + return list->size; +} + +static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list, + uint32_t new_capacity) +{ + if (new_capacity > list->capacity) { + uint32_t *ssrcs; + srtp_stream_ctx_t **stream_ptrs; + + if (new_capacity > (UINT32_MAX - 15u)) + return srtp_err_status_alloc_fail; + + new_capacity = (new_capacity + 15u) & ~((uint32_t)15u); + + ssrcs = (uint32_t *)srtp_crypto_alloc((size_t)new_capacity * + sizeof(uint32_t)); + if (!ssrcs) + return srtp_err_status_alloc_fail; + stream_ptrs = (srtp_stream_ctx_t **)srtp_crypto_alloc( + (size_t)new_capacity * sizeof(srtp_stream_ctx_t *)); + if (!stream_ptrs) { + srtp_crypto_free(ssrcs); + return srtp_err_status_alloc_fail; + } + + if (list->size > 0u) { + memcpy(ssrcs, list->ssrcs, (size_t)list->size * sizeof(uint32_t)); + memcpy(stream_ptrs, list->streams, + (size_t)list->size * sizeof(srtp_stream_ctx_t *)); + } + + srtp_crypto_free(list->ssrcs); + srtp_crypto_free(list->streams); + list->streams = stream_ptrs; + list->ssrcs = ssrcs; + + list->capacity = new_capacity; + } + + return srtp_err_status_ok; +} + srtp_err_status_t srtp_stream_list_insert(srtp_stream_list_t list, srtp_stream_t stream) { - /* insert at the head of the list */ - stream->next = list->data.next; - if (stream->next != NULL) { - stream->next->prev = stream; - } - list->data.next = stream; - stream->prev = &(list->data); + uint32_t pos; + srtp_err_status_t status = srtp_stream_list_reserve(list, list->size + 1u); + if (status) + return status; + pos = list->size++; + list->ssrcs[pos] = stream->ssrc; + list->streams[pos] = stream; return srtp_err_status_ok; } -srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc) +static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc) { - /* walk down list until ssrc is found */ - srtp_stream_t stream = list->data.next; - while (stream != NULL) { - if (stream->ssrc == ssrc) { - return stream; +#if defined(__SSE2__) + const uint32_t *const ssrcs = list->ssrcs; + const __m128i mm_ssrc = _mm_set1_epi32(ssrc); + uint32_t pos = 0u, n = (list->size + 7u) & ~(uint32_t)(7u); + for (uint32_t m = n & ~(uint32_t)(15u); pos < m; pos += 16u) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + __m128i mm3 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 8u)); + __m128i mm4 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 12u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm3 = _mm_cmpeq_epi32(mm3, mm_ssrc); + mm4 = _mm_cmpeq_epi32(mm4, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + mm3 = _mm_packs_epi32(mm3, mm4); + mm1 = _mm_packs_epi16(mm1, mm3); + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos; +#else + pos += __builtin_ctz(mask); +#endif + + goto done; + } + } + + if (pos < n) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos / 2u; +#else + pos += __builtin_ctz(mask) / 2u; +#endif + goto done; } - stream = stream->next; + + pos += 8u; + } + +done: + return pos; +#else + /* walk down list until ssrc is found */ + uint32_t pos = 0u, n = list->size; + for (; pos < n; ++pos) { + if (list->ssrcs[pos] == ssrc) + break; } + return pos; +#endif +} + +static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list, + uint32_t pos) +{ + return list->streams[pos]; +} + +srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc) +{ + uint32_t pos = srtp_stream_list_find(list, ssrc); + if (pos < list->size) + return list->streams[pos]; + /* we haven't found our ssrc, so return a null */ return NULL; } -void srtp_stream_list_remove(srtp_stream_list_t list, - srtp_stream_t stream_to_remove) +static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos) { - (void)list; + uint32_t tail_size, last_pos; - stream_to_remove->prev->next = stream_to_remove->next; - if (stream_to_remove->next != NULL) { - stream_to_remove->next->prev = stream_to_remove->prev; + last_pos = --list->size; + tail_size = last_pos - pos; + if (tail_size > 0u) { + memmove(list->streams + pos, list->streams + pos + 1, + (size_t)tail_size * sizeof(*list->streams)); + memmove(list->ssrcs + pos, list->ssrcs + pos + 1, + (size_t)tail_size * sizeof(*list->ssrcs)); } + + list->streams[last_pos] = NULL; + list->ssrcs[last_pos] = 0u; +} + +void srtp_stream_list_remove(srtp_stream_list_t list, + srtp_stream_t stream_to_remove) +{ + uint32_t pos = srtp_stream_list_find(list, stream_to_remove->ssrc); + if (pos < list->size) + srtp_stream_list_remove_at(list, pos); } void srtp_stream_list_for_each(srtp_stream_list_t list, int (*callback)(srtp_stream_t, void *), void *data) { - srtp_stream_t stream = list->data.next; - while (stream != NULL) { - srtp_stream_t tmp = stream; - stream = stream->next; - if (callback(tmp, data)) + uint32_t size = list->size; + for (uint32_t i = 0u; i < size;) { + if (callback(list->streams[i], data)) break; + + /* check if the callback removed the current element */ + if (size == list->size) + ++i; + else + size = list->size; } }