Skip to content

Commit c5f1613

Browse files
authored
Receive Path API Improvements (#101)
1 parent 0345e9a commit c5f1613

File tree

7 files changed

+244
-74
lines changed

7 files changed

+244
-74
lines changed

lib/msh3.cpp

Lines changed: 123 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,28 @@ MsH3RequestClose(
170170
delete (MsH3BiDirStream*)Handle;
171171
}
172172

173+
extern "C"
174+
void
175+
MSH3_CALL
176+
MsH3RequestCompleteReceive(
177+
MSH3_REQUEST* Handle,
178+
uint32_t Length
179+
)
180+
{
181+
((MsH3BiDirStream*)Handle)->CompleteReceive(Length);
182+
}
183+
184+
extern "C"
185+
void
186+
MSH3_CALL
187+
MsH3RequestSetReceiveEnabled(
188+
MSH3_REQUEST* Handle,
189+
bool Enabled
190+
)
191+
{
192+
((MsH3BiDirStream*)Handle)->SetReceiveEnabled(Enabled);
193+
}
194+
173195
extern "C"
174196
bool
175197
MSH3_CALL
@@ -801,10 +823,7 @@ MsH3BiDirStream::MsQuicCallback(
801823
}
802824
break;
803825
case QUIC_STREAM_EVENT_RECEIVE:
804-
for (uint32_t i = 0; i < Event->RECEIVE.BufferCount; ++i) {
805-
Receive(Event->RECEIVE.Buffers + i);
806-
}
807-
break;
826+
return Receive(Event);
808827
case QUIC_STREAM_EVENT_SEND_COMPLETE:
809828
if (Event->SEND_COMPLETE.ClientContext) {
810829
auto AppSend = (MsH3AppSend*)Event->SEND_COMPLETE.ClientContext;
@@ -830,71 +849,118 @@ MsH3BiDirStream::MsQuicCallback(
830849
return QUIC_STATUS_SUCCESS;
831850
}
832851

833-
void
852+
QUIC_STATUS
834853
MsH3BiDirStream::Receive(
835-
_In_ const QUIC_BUFFER* Buffer
854+
_Inout_ QUIC_STREAM_EVENT* Event
836855
)
837856
{
838-
uint32_t Offset = 0;
839-
840-
do {
841-
if (CurFrameLengthLeft == 0) {
842-
if (BufferedHeadersLength == 0) {
843-
if (!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &Offset, &CurFrameType) ||
844-
!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &Offset, &CurFrameLength)) {
845-
BufferedHeadersLength = Buffer->Length - Offset;
846-
memcpy(BufferedHeaders, Buffer->Buffer + Offset, BufferedHeadersLength);
847-
return;
848-
}
849-
} else {
850-
uint32_t ToCopy = sizeof(BufferedHeaders) - BufferedHeadersLength;
851-
if (ToCopy > Buffer->Length) ToCopy = Buffer->Length;
852-
memcpy(BufferedHeaders + BufferedHeadersLength, Buffer->Buffer, ToCopy);
853-
if (!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &Offset, &CurFrameType) ||
854-
!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &Offset, &CurFrameLength)) {
855-
BufferedHeadersLength += ToCopy;
856-
return;
857+
for (uint32_t i = 0; i < Event->RECEIVE.BufferCount; ++i) {
858+
const QUIC_BUFFER* Buffer = Event->RECEIVE.Buffers + i;
859+
do {
860+
if (CurFrameLengthLeft == 0) { // Not in the middle of reading frame payload
861+
if (BufferedHeadersLength == 0) { // No partial frame header bufferred
862+
if (!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &CurRecvOffset, &CurFrameType) ||
863+
!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &CurRecvOffset, &CurFrameLength)) {
864+
BufferedHeadersLength = Buffer->Length - CurRecvOffset;
865+
memcpy(BufferedHeaders, Buffer->Buffer + CurRecvOffset, BufferedHeadersLength);
866+
break;
867+
}
868+
} else { // Partial frame header bufferred already
869+
uint32_t ToCopy = sizeof(BufferedHeaders) - BufferedHeadersLength;
870+
if (ToCopy > Buffer->Length) ToCopy = Buffer->Length;
871+
memcpy(BufferedHeaders + BufferedHeadersLength, Buffer->Buffer, ToCopy);
872+
if (!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &CurRecvOffset, &CurFrameType) ||
873+
!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &CurRecvOffset, &CurFrameLength)) {
874+
BufferedHeadersLength += ToCopy;
875+
break;
876+
}
877+
CurRecvOffset -= BufferedHeadersLength;
878+
BufferedHeadersLength = 0;
857879
}
858-
Offset -= BufferedHeadersLength;
859-
BufferedHeadersLength = 0;
880+
CurFrameLengthLeft = CurFrameLength;
860881
}
861-
CurFrameLengthLeft = CurFrameLength;
862-
}
863882

864-
uint32_t AvailFrameLength;
865-
if (Offset + CurFrameLengthLeft > (uint64_t)Buffer->Length) {
866-
AvailFrameLength = Buffer->Length - Offset; // Rest of the buffer
867-
} else {
868-
AvailFrameLength = (uint32_t)CurFrameLengthLeft;
869-
}
883+
uint32_t AvailFrameLength;
884+
if (CurRecvOffset + CurFrameLengthLeft > (uint64_t)Buffer->Length) {
885+
AvailFrameLength = Buffer->Length - CurRecvOffset; // Rest of the buffer
886+
} else {
887+
AvailFrameLength = (uint32_t)CurFrameLengthLeft;
888+
}
870889

871-
if (CurFrameType == H3FrameData) {
872-
Callbacks.DataReceived((MSH3_REQUEST*)this, Context, AvailFrameLength, Buffer->Buffer + Offset);
873-
} else if (CurFrameType == H3FrameHeaders) {
874-
const uint8_t* Frame = Buffer->Buffer + Offset;
875-
if (CurFrameLengthLeft == CurFrameLength) {
876-
auto rhs =
877-
lsqpack_dec_header_in(
878-
&H3.Decoder, this, ID(), (size_t)CurFrameLength, &Frame,
879-
AvailFrameLength, nullptr, nullptr);
880-
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
881-
printf("lsqpack_dec_header_in failure res=%u\n", rhs);
890+
if (CurFrameType == H3FrameData) {
891+
uint32_t AppReceiveLength = AvailFrameLength;
892+
ReceivePending = true;
893+
if (Callbacks.DataReceived((MSH3_REQUEST*)this, Context, &AppReceiveLength, Buffer->Buffer + CurRecvOffset)) {
894+
ReceivePending = false; // Not pending receive
895+
if (AppReceiveLength < AvailFrameLength) { // Partial receive case
896+
CurFrameLengthLeft -= AppReceiveLength;
897+
Event->RECEIVE.TotalBufferLength = CurRecvCompleteLength + CurRecvOffset + AppReceiveLength;
898+
CurRecvCompleteLength = 0;
899+
CurRecvOffset = 0;
900+
return QUIC_STATUS_SUCCESS;
901+
}
902+
} else { // Receive pending (but may have been completed via API call already)
903+
if (!ReceivePending) {
904+
// TODO - Support continuing this receive since it was completed via the API call
905+
}
906+
return QUIC_STATUS_PENDING;
882907
}
883-
} else { // Continued from a previous partial read
884-
auto rhs =
885-
lsqpack_dec_header_read(
886-
&H3.Decoder, this, &Frame, AvailFrameLength, nullptr,
887-
nullptr);
888-
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
889-
printf("lsqpack_dec_header_read failure res=%u\n", rhs);
908+
} else if (CurFrameType == H3FrameHeaders) {
909+
const uint8_t* Frame = Buffer->Buffer + CurRecvOffset;
910+
if (CurFrameLengthLeft == CurFrameLength) {
911+
auto rhs =
912+
lsqpack_dec_header_in(
913+
&H3.Decoder, this, ID(), (size_t)CurFrameLength, &Frame,
914+
AvailFrameLength, nullptr, nullptr);
915+
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
916+
printf("lsqpack_dec_header_in failure res=%u\n", rhs);
917+
}
918+
} else { // Continued from a previous partial read
919+
auto rhs =
920+
lsqpack_dec_header_read(
921+
&H3.Decoder, this, &Frame, AvailFrameLength, nullptr,
922+
nullptr);
923+
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
924+
printf("lsqpack_dec_header_read failure res=%u\n", rhs);
925+
}
890926
}
891927
}
892-
}
893928

894-
CurFrameLengthLeft -= AvailFrameLength;
895-
Offset += AvailFrameLength;
929+
CurFrameLengthLeft -= AvailFrameLength;
930+
CurRecvOffset += AvailFrameLength;
931+
932+
} while (CurRecvOffset < Buffer->Length);
933+
934+
CurRecvCompleteLength += Buffer->Length;
935+
CurRecvOffset = 0;
936+
}
896937

897-
} while (Offset < Buffer->Length);
938+
CurRecvCompleteLength = 0;
939+
940+
return QUIC_STATUS_SUCCESS;
941+
}
942+
943+
void
944+
MsH3BiDirStream::CompleteReceive(
945+
_In_ uint32_t Length
946+
)
947+
{
948+
if (ReceivePending) {
949+
ReceivePending = false;
950+
CurFrameLengthLeft -= Length;
951+
auto CompleteLength = CurRecvCompleteLength + CurRecvOffset + Length;
952+
CurRecvCompleteLength = 0;
953+
CurRecvOffset = 0;
954+
(void)ReceiveComplete(CompleteLength);
955+
}
956+
}
957+
958+
void
959+
MsH3BiDirStream::SetReceiveEnabled(
960+
_In_ bool Enabled
961+
)
962+
{
963+
(void)ReceiveSetEnabled(Enabled);
898964
}
899965

900966
struct lsxpack_header*

lib/msh3.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,12 +497,15 @@ struct MsH3BiDirStream : public MsQuicStream {
497497
QUIC_VAR_INT CurFrameType {0};
498498
QUIC_VAR_INT CurFrameLength {0};
499499
QUIC_VAR_INT CurFrameLengthLeft {0};
500+
uint64_t CurRecvCompleteLength {0};
501+
uint32_t CurRecvOffset {0};
500502

501503
uint8_t BufferedHeaders[2*sizeof(uint64_t)];
502504
uint32_t BufferedHeadersLength {0};
503505

504506
bool Complete {false};
505507
bool ShutdownComplete {false};
508+
bool ReceivePending {false};
506509

507510
MsH3BiDirStream(
508511
_In_ MsH3Connection& Connection,
@@ -521,6 +524,16 @@ struct MsH3BiDirStream : public MsQuicStream {
521524
);
522525
#endif
523526

527+
void
528+
CompleteReceive(
529+
_In_ uint32_t Length
530+
);
531+
532+
void
533+
SetReceiveEnabled(
534+
_In_ bool Enabled
535+
);
536+
524537
bool
525538
SendAppData(
526539
_In_ MSH3_REQUEST_FLAGS Flags,
@@ -551,9 +564,9 @@ struct MsH3BiDirStream : public MsQuicStream {
551564

552565
private:
553566

554-
void
567+
QUIC_STATUS
555568
Receive(
556-
_In_ const QUIC_BUFFER* Buffer
569+
_Inout_ QUIC_STREAM_EVENT* Event
557570
);
558571

559572
static QUIC_STATUS

lib/win32/msh3.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ EXPORTS
1111
MsH3ConnectionSetCertificate
1212
MsH3RequestOpen
1313
MsH3RequestClose
14+
MsH3RequestCompleteReceive
15+
MsH3RequestSetReceiveEnabled
1416
MsH3RequestSend
1517
MsH3RequestShutdown
1618
MsH3RequestSetCallbackInterface

msh3.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ MsH3ConnectionSetCertificate(
209209

210210
typedef struct MSH3_REQUEST_IF {
211211
void (MSH3_CALL *HeaderReceived)(MSH3_REQUEST* Request, void* IfContext, const MSH3_HEADER* Header);
212-
void (MSH3_CALL *DataReceived)(MSH3_REQUEST* Request, void* IfContext, uint32_t Length, const uint8_t* Data);
212+
bool (MSH3_CALL *DataReceived)(MSH3_REQUEST* Request, void* IfContext, uint32_t* Length, const uint8_t* Data);
213213
void (MSH3_CALL *Complete)(MSH3_REQUEST* Request, void* IfContext, bool Aborted, uint64_t AbortError);
214214
void (MSH3_CALL *ShutdownComplete)(MSH3_REQUEST* Request, void* IfContext);
215215
void (MSH3_CALL *DataSent)(MSH3_REQUEST* Request, void* IfContext, void* SendContext);
@@ -232,6 +232,20 @@ MsH3RequestClose(
232232
MSH3_REQUEST* Handle
233233
);
234234

235+
void
236+
MSH3_CALL
237+
MsH3RequestCompleteReceive(
238+
MSH3_REQUEST* Handle,
239+
uint32_t Length
240+
);
241+
242+
void
243+
MSH3_CALL
244+
MsH3RequestSetReceiveEnabled(
245+
MSH3_REQUEST* Handle,
246+
bool Enabled
247+
);
248+
235249
bool
236250
MSH3_CALL
237251
MsH3RequestSend(

test/msh3test.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,68 @@ DEF_TEST(SimpleRequest) {
8686
return true;
8787
}
8888

89+
bool ReceiveData(bool Async, bool Inline = true) {
90+
struct Context {
91+
bool Async; bool Inline;
92+
TestWaitable<uint32_t> Data;
93+
Context(bool Async, bool Inline) : Async(Async), Inline(Inline) {}
94+
static bool RecvData(TestRequest* Request, uint32_t* Length, const uint8_t* /* Data */) {
95+
auto ctx = (Context*)Request->AppContext;
96+
ctx->Data.Set(*Length);
97+
if (ctx->Async) {
98+
if (ctx->Inline) {
99+
Request->CompleteReceive(*Length);
100+
}
101+
return false;
102+
}
103+
return true;
104+
}
105+
};
106+
TestApi Api; VERIFY(Api.IsValid());
107+
TestCertificate Cert(Api); VERIFY(Cert.IsValid());
108+
TestListener Listener(Api); VERIFY(Listener.IsValid());
109+
TestConnection Connection(Api); VERIFY(Connection.IsValid());
110+
Context Context(Async, Inline);
111+
TestRequest Request(Connection, RequestHeaders, RequestHeadersCount, MSH3_REQUEST_FLAG_FIN, &Context, nullptr, Context::RecvData);
112+
VERIFY(Request.IsValid());
113+
VERIFY(Listener.NewConnection.WaitFor());
114+
auto ServerConnection = Listener.NewConnection.Get();
115+
ServerConnection->SetCertificate(Cert);
116+
VERIFY(ServerConnection->Connected.WaitFor());
117+
VERIFY(Connection.Connected.WaitFor());
118+
VERIFY(ServerConnection->NewRequest.WaitFor());
119+
auto ServerRequest = ServerConnection->NewRequest.Get();
120+
VERIFY(ServerRequest->Send(MSH3_REQUEST_FLAG_FIN, ResponseData, sizeof(ResponseData)));
121+
VERIFY(Context.Data.WaitFor());
122+
VERIFY(Context.Data.Get() == sizeof(ResponseData));
123+
if (Async && !Inline) {
124+
Request.CompleteReceive(Context.Data.Get());
125+
}
126+
VERIFY(Request.Complete.WaitFor());
127+
VERIFY(ServerRequest->Complete.WaitFor());
128+
return true;
129+
}
130+
131+
DEF_TEST(ReceiveDataInline) {
132+
return ReceiveData(false);
133+
}
134+
135+
DEF_TEST(ReceiveDataAsync) {
136+
return ReceiveData(true, false);
137+
}
138+
139+
DEF_TEST(ReceiveDataAsyncInline) {
140+
return ReceiveData(true, true);
141+
}
142+
89143
const TestFunc TestFunctions[] = {
90144
ADD_TEST(Handshake),
91145
ADD_TEST(HandshakeFail),
92146
//ADD_TEST(HandshakeSetCertTimeout),
93-
ADD_TEST(SimpleRequest)
147+
ADD_TEST(SimpleRequest),
148+
ADD_TEST(ReceiveDataInline),
149+
ADD_TEST(ReceiveDataAsync),
150+
ADD_TEST(ReceiveDataAsyncInline),
94151
};
95152
const uint32_t TestCount = sizeof(TestFunctions)/sizeof(TestFunc);
96153

0 commit comments

Comments
 (0)