Skip to content

Commit cf10c0b

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Partly vectorize CompactProtocol list read (facebook#9606)
Summary: Pull Request resolved: facebook#9606 Partly vectorize CompactProtocol's list reading, mainly on aarch64. Performance gains varies by type: before: CompactProtocol_read_SmallListInt 36.10ns 27.70M CompactProtocol_read_BigListByte 18.32us 54.57K 10005 CompactProtocol_read_BigListShort 27.57us 36.27K 27489 CompactProtocol_read_BigListInt 22.74us 43.97K 49370 CompactProtocol_read_BigListBigInt 25.26us 39.59K 49696 CompactProtocol_read_BigListFloat 18.62us 53.69K 40005 CompactProtocol_read_BigListDouble 18.81us 53.16K 80005 after: CompactProtocol_read_SmallListInt 27.07ns 36.94M 52 CompactProtocol_read_BigListByte 185.48ns 5.39M 10005 CompactProtocol_read_BigListShort 6.01us 166.50K 27489 CompactProtocol_read_BigListInt 8.67us 115.37K 49370 CompactProtocol_read_BigListBigInt 11.33us 88.26K 49696 CompactProtocol_read_BigListFloat 827.75ns 1.21M 40005 CompactProtocol_read_BigListDouble 1.67us 600.49K 80005 Differential Revision: D73063243
1 parent 18a5273 commit cf10c0b

File tree

4 files changed

+264
-2
lines changed

4 files changed

+264
-2
lines changed

third-party/thrift/src/thrift/lib/cpp2/protocol/CompactProtocol.cpp

Lines changed: 253 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,258 @@ size_t CompactProtocolWriter::writeArithmeticVector<double>(
462462
out_, inputPtr, numElements);
463463
}
464464

465-
#endif // FOLLY_AARCH64
465+
// Decodes compacted zigzag varints in a vectorized manner
466+
template <class Cursor, typename T>
467+
static inline void readEncodedArithmeticVectorSIMD(
468+
Cursor& c, T* outputPtr, size_t numElements) {
469+
constexpr size_t simdWidth = sizeof(uint8x16_t) / sizeof(T);
470+
size_t i = 0;
471+
size_t loopBound = numElements - (numElements % simdWidth);
472+
while (i < numElements) {
473+
const uint8_t* inPtr = c.data();
474+
size_t len = c.length();
475+
const uint8_t* endSimd =
476+
inPtr + len - util::detail::kVarintMaxBytes<T> * simdWidth;
477+
const uint8_t* endScalar = inPtr + len - util::detail::kVarintMaxBytes<T>;
478+
const uint8_t* start = inPtr;
479+
for (; i < loopBound && inPtr <= endSimd; i += simdWidth) {
480+
if constexpr (sizeof(T) == 4) {
481+
uint32x4_t vec;
482+
T value;
483+
inPtr +=
484+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
485+
vec[0] = value;
486+
inPtr +=
487+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
488+
vec[1] = value;
489+
inPtr +=
490+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
491+
vec[2] = value;
492+
inPtr +=
493+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
494+
vec[3] = value;
495+
uint32x4_t vecBit = vshlq_n_u32(vec, 31);
496+
vecBit = vreinterpretq_u32_s32(
497+
vshrq_n_s32(vreinterpretq_s32_u32(vecBit), 30));
498+
vec = svget_neonq_u32(svxar_n_u32(
499+
svset_neonq_u32(svundef_u32(), vec),
500+
svset_neonq_u32(svundef_u32(), vecBit),
501+
1));
502+
vst1q_u32(reinterpret_cast<uint32_t*>(outputPtr + i), vec);
503+
} else if constexpr (sizeof(T) == 2) {
504+
uint16x8_t vec;
505+
T value;
506+
inPtr +=
507+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
508+
vec[0] = value;
509+
inPtr +=
510+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
511+
vec[1] = value;
512+
inPtr +=
513+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
514+
vec[2] = value;
515+
inPtr +=
516+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
517+
vec[3] = value;
518+
inPtr +=
519+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
520+
vec[4] = value;
521+
inPtr +=
522+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
523+
vec[5] = value;
524+
inPtr +=
525+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
526+
vec[6] = value;
527+
inPtr +=
528+
util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
529+
vec[7] = value;
530+
uint16x8_t vecBit = vshlq_n_u16(vec, 15);
531+
vecBit = vreinterpretq_u16_s16(
532+
vshrq_n_s16(vreinterpretq_s16_u16(vecBit), 14));
533+
vec = svget_neonq_u16(svxar_n_u16(
534+
svset_neonq_u16(svundef_u16(), vec),
535+
svset_neonq_u16(svundef_u16(), vecBit),
536+
1));
537+
vst1q_u16(reinterpret_cast<uint16_t*>(outputPtr + i), vec);
538+
}
539+
}
540+
for (; i < numElements && inPtr <= endScalar; ++i) {
541+
int32_t value;
542+
inPtr += util::detail::readVarintMediumSlowUnrolledAarch64(value, inPtr);
543+
outputPtr[i] = (T)util::detail::zigzagToSignedInt(value);
544+
}
545+
size_t consumed = inPtr - start;
546+
c.skip(consumed);
547+
len -= consumed;
548+
size_t trailingLoopBound = std::min(numElements, i + len + 1);
549+
for (; i < trailingLoopBound; ++i) {
550+
// Need to finish consuming current input buffer
551+
int32_t value;
552+
util::detail::readVarintSlow(c, value);
553+
outputPtr[i] = (T)util::detail::zigzagToSignedInt(value);
554+
}
555+
}
556+
}
557+
558+
#endif // FOLLY_ARM_FEATURE_NEON_SVE_BRIDGE
559+
560+
// Function used with data types that are decoded from compacted zigzag
561+
template <class Cursor, typename T>
562+
static inline void readEncodedArithmeticVector(
563+
Cursor& c, T* outputPtr, size_t numElements) {
564+
size_t i = 0;
565+
size_t numElementsMod = numElements & 1;
566+
size_t loopBound = numElements - numElementsMod;
567+
while (i < numElements) {
568+
const uint8_t* inPtr = c.data();
569+
const uint8_t* start = inPtr;
570+
size_t len = c.length();
571+
constexpr size_t kMaxVarintBytes = sizeof(T) == 2
572+
? util::detail::kVarintMaxBytes<int32_t>
573+
: util::detail::kVarintMaxBytes<T>;
574+
const uint8_t* endVec = inPtr + len - kMaxVarintBytes * 2;
575+
for (; i < loopBound && inPtr <= endVec; i += 2) {
576+
if constexpr (sizeof(T) == 2) {
577+
int32_t valueA;
578+
int32_t valueB;
579+
inPtr += util::detail::readVarintMediumSlowUnrolled(valueA, inPtr);
580+
inPtr += util::detail::readVarintMediumSlowUnrolled(valueB, inPtr);
581+
valueA = util::detail::zigzagToSignedInt(valueA);
582+
valueB = util::detail::zigzagToSignedInt(valueB);
583+
outputPtr[i] = (T)valueA;
584+
outputPtr[i + 1] = (T)valueB;
585+
} else {
586+
T valueA;
587+
T valueB;
588+
inPtr += util::detail::readVarintMediumSlowUnrolled(valueA, inPtr);
589+
inPtr += util::detail::readVarintMediumSlowUnrolled(valueB, inPtr);
590+
valueA = util::detail::zigzagToSignedInt(valueA);
591+
valueB = util::detail::zigzagToSignedInt(valueB);
592+
outputPtr[i] = valueA;
593+
outputPtr[i + 1] = valueB;
594+
}
595+
}
596+
size_t consumed = inPtr - start;
597+
c.skipNoAdvance(consumed);
598+
len -= consumed;
599+
size_t trailingLoopBound = std::min(numElements, i + len + 1);
600+
while (i < trailingLoopBound) {
601+
if constexpr (sizeof(T) == 2) {
602+
// Need to finish consuming current input buffer
603+
int32_t value;
604+
util::detail::readVarintSlow(c, value);
605+
outputPtr[i] = (T)util::detail::zigzagToSignedInt(value);
606+
} else {
607+
// Need to finish consuming current input buffer
608+
T value;
609+
util::detail::readVarintSlow(c, value);
610+
outputPtr[i] = util::detail::zigzagToSignedInt(value);
611+
}
612+
++i;
613+
}
614+
}
615+
}
616+
617+
#if !FOLLY_ARM_FEATURE_NEON_SVE_BRIDGE
618+
// Decodes compacted zigzag varints in a vectorized manner
619+
template <class Cursor, typename T>
620+
static inline void readEncodedArithmeticVectorSIMD(
621+
Cursor& c, T* outputPtr, size_t numElements) {
622+
return readEncodedArithmeticVector<Cursor, T>(c, outputPtr, numElements);
623+
}
624+
#endif // !FOLLY_ARM_FEATURE_NEON_SVE_BRIDGE
625+
626+
// Function used with data types that are just received as BE/LE bytes
627+
template <class Cursor, typename T, bool BE>
628+
static inline void readUnencodedArithmeticVector(
629+
Cursor& c, T* outputPtr, size_t numElements) {
630+
size_t i = 0;
631+
while (i < numElements) {
632+
const uint8_t* inPtr = c.data();
633+
size_t len = c.length();
634+
size_t loopBound = std::min(numElements, i + len / sizeof(T));
635+
size_t j = 0;
636+
for (; i < loopBound; ++i, ++j) {
637+
T value = BE ? folly::Endian::big<T>(
638+
folly::loadUnaligned<T>(inPtr + j * sizeof(T)))
639+
: folly::loadUnaligned<T>(inPtr + j * sizeof(T));
640+
outputPtr[i] = value;
641+
}
642+
c.skipNoAdvance(j * sizeof(T));
643+
if (i < numElements) {
644+
if constexpr (sizeof(T) == 8) {
645+
uint64_t bits = c.template readBE<int64_t>();
646+
outputPtr[i] = folly::bit_cast<double>(bits);
647+
} else if constexpr (sizeof(T) == 4) {
648+
uint32_t bits = c.template readBE<int32_t>();
649+
outputPtr[i] = folly::bit_cast<float>(bits);
650+
} else {
651+
outputPtr[i] = c.template read<int8_t>();
652+
}
653+
++i;
654+
}
655+
}
656+
}
657+
658+
template <>
659+
void CompactProtocolReader::readArithmeticVector<int64_t>(
660+
int64_t* outputPtr, size_t numElements) {
661+
return readEncodedArithmeticVector<Cursor, int64_t>(
662+
in_, outputPtr, numElements);
663+
}
664+
template <>
665+
void CompactProtocolReader::readArithmeticVector<uint64_t>(
666+
uint64_t* outputPtr, size_t numElements) {
667+
return readEncodedArithmeticVector<Cursor, uint64_t>(
668+
in_, outputPtr, numElements);
669+
}
670+
template <>
671+
void CompactProtocolReader::readArithmeticVector<int32_t>(
672+
int32_t* outputPtr, size_t numElements) {
673+
return readEncodedArithmeticVectorSIMD<Cursor, int32_t>(
674+
in_, outputPtr, numElements);
675+
}
676+
template <>
677+
void CompactProtocolReader::readArithmeticVector<uint32_t>(
678+
uint32_t* outputPtr, size_t numElements) {
679+
return readEncodedArithmeticVectorSIMD<Cursor, uint32_t>(
680+
in_, outputPtr, numElements);
681+
}
682+
template <>
683+
void CompactProtocolReader::readArithmeticVector<int16_t>(
684+
int16_t* outputPtr, size_t numElements) {
685+
return readEncodedArithmeticVectorSIMD<Cursor, int16_t>(
686+
in_, outputPtr, numElements);
687+
}
688+
template <>
689+
void CompactProtocolReader::readArithmeticVector<uint16_t>(
690+
uint16_t* outputPtr, size_t numElements) {
691+
return readEncodedArithmeticVectorSIMD<Cursor, uint16_t>(
692+
in_, outputPtr, numElements);
693+
}
694+
template <>
695+
void CompactProtocolReader::readArithmeticVector<int8_t>(
696+
int8_t* outputPtr, size_t numElements) {
697+
return readUnencodedArithmeticVector<Cursor, int8_t, false>(
698+
in_, outputPtr, numElements);
699+
}
700+
template <>
701+
void CompactProtocolReader::readArithmeticVector<uint8_t>(
702+
uint8_t* outputPtr, size_t numElements) {
703+
return readUnencodedArithmeticVector<Cursor, uint8_t, false>(
704+
in_, outputPtr, numElements);
705+
}
706+
template <>
707+
void CompactProtocolReader::readArithmeticVector<float>(
708+
float* outputPtr, size_t numElements) {
709+
return readUnencodedArithmeticVector<Cursor, float, true>(
710+
in_, outputPtr, numElements);
711+
}
712+
template <>
713+
void CompactProtocolReader::readArithmeticVector<double>(
714+
double* outputPtr, size_t numElements) {
715+
return readUnencodedArithmeticVector<Cursor, double, true>(
716+
in_, outputPtr, numElements);
717+
}
466718

467719
} // namespace apache::thrift

third-party/thrift/src/thrift/lib/cpp2/protocol/CompactProtocol.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ class CompactProtocolReader : public detail::ProtocolBase {
253253

254254
static constexpr bool kHasDeferredRead() { return true; }
255255

256+
static constexpr bool kSupportsArithmeticVectors() { return true; }
257+
256258
void setStringSizeLimit(int32_t string_limit) {
257259
string_limit_ = string_limit;
258260
}
@@ -294,6 +296,8 @@ class CompactProtocolReader : public detail::ProtocolBase {
294296
void readI64(int64_t& i64);
295297
void readDouble(double& dub);
296298
void readFloat(float& flt);
299+
template <typename T>
300+
void readArithmeticVector(T* outputPtr, size_t numElements);
297301
template <typename StrType>
298302
void readString(StrType& str);
299303
template <typename StrType>

third-party/thrift/src/thrift/lib/cpp2/protocol/CompactV1Protocol.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ class CompactV1ProtocolReader : protected CompactProtocolReader {
139139

140140
using CompactProtocolReader::getCursor;
141141
using CompactProtocolReader::getCursorPosition;
142+
143+
static constexpr bool kSupportsArithmeticVectors() { return false; }
142144
};
143145

144146
} // namespace apache::thrift

third-party/thrift/src/thrift/lib/cpp2/protocol/test/ProtocolTest.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,11 @@ void runBigListTest(
225225
} else {
226226
prot_method_integral::read(r, outList);
227227
}
228-
ASSERT_EQ(intList, outList);
228+
ASSERT_EQ(intList.size(), outList.size());
229+
size_t len = std::min(intList.size(), outList.size());
230+
for (size_t j = 0; j < len; ++j) {
231+
ASSERT_EQ(intList[j], outList[j]);
232+
}
229233
}
230234
}
231235
}

0 commit comments

Comments
 (0)