Skip to content

Commit de9e9b1

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Partly vectorize CompactProtocol list read (facebook#9606)
Summary: Partly vectorize CompactProtocol's list reading, mainly on aarch64. Performance gains varies by type: before: CompactProtocol_read_SmallListInt 36.10ns 27.70M CompactProtocol_read_BigListByte 18.32us 54.57K 10005 CompactProtocol_read_BigListShort 27.57us 36.27K 27489 CompactProtocol_read_BigListInt 22.74us 43.97K 49370 CompactProtocol_read_BigListBigInt 25.26us 39.59K 49696 CompactProtocol_read_BigListFloat 18.62us 53.69K 40005 CompactProtocol_read_BigListDouble 18.81us 53.16K 80005 after: CompactProtocol_read_SmallListInt 27.07ns 36.94M 52 CompactProtocol_read_BigListByte 185.48ns 5.39M 10005 CompactProtocol_read_BigListShort 5.97us 167.42K 27489 CompactProtocol_read_BigListInt 8.67us 115.37K 49370 CompactProtocol_read_BigListBigInt 13.01us 76.87K 49696 CompactProtocol_read_BigListFloat 827.75ns 1.21M 40005 CompactProtocol_read_BigListDouble 1.67us 600.49K 80005 Differential Revision: D73063243
1 parent 3092957 commit de9e9b1

File tree

3 files changed

+228
-1
lines changed

3 files changed

+228
-1
lines changed

third-party/thrift/src/thrift/lib/cpp2/protocol/CompactProtocol.cpp

Lines changed: 221 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,226 @@ size_t CompactProtocolWriter::writeArithmeticVector<double>(
447447
out_, inputPtr, numElements);
448448
}
449449

450-
#endif // FOLLY_AARCH64
450+
// Decodes compacted zigzag varints in a vectorized manner
451+
template <class Cursor, typename T>
452+
static inline void readEncodedArithmeticVectorSIMD(
453+
Cursor& c, T* outputPtr, size_t numElements) {
454+
constexpr size_t simdWidth = sizeof(uint8x16_t) / sizeof(T);
455+
size_t len = c.length();
456+
size_t i = 0;
457+
size_t loopBound = numElements - (numElements % simdWidth);
458+
while (i < numElements && len >= sizeof(T)) {
459+
const uint8_t* inPtr = c.data();
460+
const uint8_t* endSimd =
461+
inPtr + len - util::detail::kVarintMaxBytes<T> * simdWidth;
462+
const uint8_t* endScalar = inPtr + len - util::detail::kVarintMaxBytes<T>;
463+
const uint8_t* start = inPtr;
464+
for (; i < loopBound && inPtr <= endSimd; i += simdWidth) {
465+
if constexpr (sizeof(T) == 4) {
466+
uint32x4_t vec;
467+
T value;
468+
inPtr +=
469+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
470+
vec[0] = value;
471+
inPtr +=
472+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
473+
vec[1] = value;
474+
inPtr +=
475+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
476+
vec[2] = value;
477+
inPtr +=
478+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
479+
vec[3] = value;
480+
uint32x4_t vecBit = vshlq_n_u32(vec, 31);
481+
vecBit = vreinterpretq_u32_s32(
482+
vshrq_n_s32(vreinterpretq_s32_u32(vecBit), 30));
483+
vec = svget_neonq_u32(svxar_n_u32(
484+
svset_neonq_u32(svundef_u32(), vec),
485+
svset_neonq_u32(svundef_u32(), vecBit),
486+
1));
487+
vst1q_u32(reinterpret_cast<uint32_t*>(outputPtr + i), vec);
488+
} else if constexpr (sizeof(T) == 2) {
489+
uint16x8_t vec;
490+
T value;
491+
inPtr +=
492+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
493+
vec[0] = value;
494+
inPtr +=
495+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
496+
vec[1] = value;
497+
inPtr +=
498+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
499+
vec[2] = value;
500+
inPtr +=
501+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
502+
vec[3] = value;
503+
inPtr +=
504+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
505+
vec[4] = value;
506+
inPtr +=
507+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
508+
vec[5] = value;
509+
inPtr +=
510+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
511+
vec[6] = value;
512+
inPtr +=
513+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
514+
vec[7] = value;
515+
uint16x8_t vecBit = vshlq_n_u16(vec, 15);
516+
vecBit = vreinterpretq_u16_s16(
517+
vshrq_n_s16(vreinterpretq_s16_u16(vecBit), 14));
518+
vec = svget_neonq_u16(svxar_n_u16(
519+
svset_neonq_u16(svundef_u16(), vec),
520+
svset_neonq_u16(svundef_u16(), vecBit),
521+
1));
522+
vst1q_u16(reinterpret_cast<uint16_t*>(outputPtr + i), vec);
523+
}
524+
}
525+
for (; i < numElements && inPtr <= endScalar; ++i) {
526+
T value;
527+
inPtr += util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
528+
outputPtr[i] = util::detail::zigzagToSignedInt(value);
529+
}
530+
size_t consumed = inPtr - start;
531+
c.skip(consumed);
532+
len -= consumed;
533+
size_t trailingLoopBound = std::min(numElements, i + len);
534+
for (; i < trailingLoopBound; ++i) {
535+
// Need to finish consuming current input buffer
536+
T value;
537+
util::detail::readVarintSlow(c, value);
538+
outputPtr[i] = util::detail::zigzagToSignedInt(value);
539+
}
540+
}
541+
}
542+
543+
#endif // FOLLY_ARM_FEATURE_NEON_SVE_BRIDGE
544+
545+
// Function used with data types that are decoded from compacted zigzag
546+
template <class Cursor, typename T>
547+
static inline void readEncodedArithmeticVector(
548+
Cursor& c, T* outputPtr, size_t numElements) {
549+
size_t i = 0;
550+
size_t numElementsMod = numElements & 1;
551+
size_t loopBound = numElements - numElementsMod;
552+
while (i < numElements) {
553+
const uint8_t* inPtr = c.data();
554+
const uint8_t* start = inPtr;
555+
size_t len = c.length();
556+
const uint8_t* endVec = inPtr + len - util::detail::kVarintMaxBytes<T> * 2;
557+
for (; i < loopBound && inPtr <= endVec; i += 2) {
558+
T valueA;
559+
T valueB;
560+
inPtr += util::detail::readVarintMediumSlowUnrolled(valueA, inPtr);
561+
inPtr += util::detail::readVarintMediumSlowUnrolled(valueB, inPtr);
562+
valueA = util::detail::zigzagToSignedInt(valueA);
563+
valueB = util::detail::zigzagToSignedInt(valueB);
564+
outputPtr[i] = valueA;
565+
outputPtr[i + 1] = valueB;
566+
}
567+
size_t consumed = inPtr - start;
568+
c.skip(consumed);
569+
len -= consumed;
570+
size_t trailingLoopBound = std::min(numElements, i + len);
571+
while (i < trailingLoopBound) {
572+
// Need to finish consuming current input buffer
573+
T value;
574+
util::detail::readVarintSlow(c, value);
575+
outputPtr[i] = util::detail::zigzagToSignedInt(value);
576+
++i;
577+
}
578+
}
579+
}
580+
581+
#if !FOLLY_ARM_FEATURE_NEON_SVE_BRIDGE
582+
// Decodes compacted zigzag varints in a vectorized manner
583+
template <class Cursor, typename T>
584+
static inline void readEncodedArithmeticVectorSIMD(
585+
Cursor& c, T* outputPtr, size_t numElements) {
586+
return readEncodedArithmeticVector<Cursor, T>(c, outputPtr, numElements);
587+
}
588+
#endif // !FOLLY_ARM_FEATURE_NEON_SVE_BRIDGE
589+
590+
// Function used with data types that are just received as BE/LE bytes
591+
template <class Cursor, typename T, bool BE>
592+
static inline void readUnencodedArithmeticVector(
593+
Cursor& c, T* outputPtr, size_t numElements) {
594+
size_t i = 0;
595+
size_t len = c.length();
596+
while (i < numElements && len >= sizeof(T)) {
597+
const uint8_t* inPtr = c.data();
598+
size_t loopBound = std::min(numElements, i + len / sizeof(T));
599+
size_t j = 0;
600+
for (; i < loopBound; ++i, ++j) {
601+
T value = BE ? folly::Endian::big<T>(
602+
folly::loadUnaligned<T>(inPtr + j * sizeof(T)))
603+
: folly::loadUnaligned<T>(inPtr + j * sizeof(T));
604+
outputPtr[i] = value;
605+
}
606+
c.skip(j * sizeof(T));
607+
len = c.length();
608+
}
609+
}
610+
611+
template <>
612+
void CompactProtocolReader::readArithmeticVector<int64_t>(
613+
int64_t* outputPtr, size_t numElements) {
614+
return readEncodedArithmeticVector<Cursor, int64_t>(
615+
in_, outputPtr, numElements);
616+
}
617+
template <>
618+
void CompactProtocolReader::readArithmeticVector<uint64_t>(
619+
uint64_t* outputPtr, size_t numElements) {
620+
return readEncodedArithmeticVector<Cursor, uint64_t>(
621+
in_, outputPtr, numElements);
622+
}
623+
template <>
624+
void CompactProtocolReader::readArithmeticVector<int32_t>(
625+
int32_t* outputPtr, size_t numElements) {
626+
return readEncodedArithmeticVectorSIMD<Cursor, int32_t>(
627+
in_, outputPtr, numElements);
628+
}
629+
template <>
630+
void CompactProtocolReader::readArithmeticVector<uint32_t>(
631+
uint32_t* outputPtr, size_t numElements) {
632+
return readEncodedArithmeticVectorSIMD<Cursor, uint32_t>(
633+
in_, outputPtr, numElements);
634+
}
635+
template <>
636+
void CompactProtocolReader::readArithmeticVector<int16_t>(
637+
int16_t* outputPtr, size_t numElements) {
638+
return readEncodedArithmeticVectorSIMD<Cursor, int16_t>(
639+
in_, outputPtr, numElements);
640+
}
641+
template <>
642+
void CompactProtocolReader::readArithmeticVector<uint16_t>(
643+
uint16_t* outputPtr, size_t numElements) {
644+
return readEncodedArithmeticVectorSIMD<Cursor, uint16_t>(
645+
in_, outputPtr, numElements);
646+
}
647+
template <>
648+
void CompactProtocolReader::readArithmeticVector<int8_t>(
649+
int8_t* outputPtr, size_t numElements) {
650+
return readUnencodedArithmeticVector<Cursor, int8_t, false>(
651+
in_, outputPtr, numElements);
652+
}
653+
template <>
654+
void CompactProtocolReader::readArithmeticVector<uint8_t>(
655+
uint8_t* outputPtr, size_t numElements) {
656+
return readUnencodedArithmeticVector<Cursor, uint8_t, false>(
657+
in_, outputPtr, numElements);
658+
}
659+
template <>
660+
void CompactProtocolReader::readArithmeticVector<float>(
661+
float* outputPtr, size_t numElements) {
662+
return readUnencodedArithmeticVector<Cursor, float, true>(
663+
in_, outputPtr, numElements);
664+
}
665+
template <>
666+
void CompactProtocolReader::readArithmeticVector<double>(
667+
double* outputPtr, size_t numElements) {
668+
return readUnencodedArithmeticVector<Cursor, double, true>(
669+
in_, outputPtr, numElements);
670+
}
451671

452672
} // namespace apache::thrift

third-party/thrift/src/thrift/lib/cpp2/protocol/CompactProtocol.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ class CompactProtocolReader : public detail::ProtocolBase {
261261

262262
static constexpr bool kHasDeferredRead() { return true; }
263263

264+
static constexpr bool kSupportsArithmeticVectors() { return true; }
265+
264266
void setStringSizeLimit(int32_t string_limit) {
265267
string_limit_ = string_limit;
266268
}
@@ -302,6 +304,8 @@ class CompactProtocolReader : public detail::ProtocolBase {
302304
void readI64(int64_t& i64);
303305
void readDouble(double& dub);
304306
void readFloat(float& flt);
307+
template <typename T>
308+
void readArithmeticVector(T* outputPtr, size_t numElements);
305309
template <typename StrType>
306310
void readString(StrType& str);
307311
template <typename StrType>

third-party/thrift/src/thrift/lib/cpp2/protocol/CompactV1Protocol.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CompactV1ProtocolReader : protected CompactProtocolReader {
137137
using CompactProtocolReader::peekList;
138138
using CompactProtocolReader::peekMap;
139139
using CompactProtocolReader::peekSet;
140+
using CompactProtocolReader::readArithmeticVector;
140141
using CompactProtocolReader::readBinary;
141142
using CompactProtocolReader::readFloat;
142143
using CompactProtocolReader::readString;
@@ -145,6 +146,8 @@ class CompactV1ProtocolReader : protected CompactProtocolReader {
145146

146147
using CompactProtocolReader::getCursor;
147148
using CompactProtocolReader::getCursorPosition;
149+
150+
static constexpr bool kSupportsArithmeticVectors() { return true; }
148151
};
149152

150153
} // namespace apache::thrift

0 commit comments

Comments
 (0)