Skip to content

Commit c5f1613

Browse files
authoredJan 1, 2023
Receive Path API Improvements (#101)
1 parent 0345e9a commit c5f1613

File tree

7 files changed

+244
-74
lines changed

7 files changed

+244
-74
lines changed
 

‎lib/msh3.cpp

+123-57
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,28 @@ MsH3RequestClose(
170170
delete (MsH3BiDirStream*)Handle;
171171
}
172172

173+
extern "C"
174+
void
175+
MSH3_CALL
176+
MsH3RequestCompleteReceive(
177+
MSH3_REQUEST* Handle,
178+
uint32_t Length
179+
)
180+
{
181+
((MsH3BiDirStream*)Handle)->CompleteReceive(Length);
182+
}
183+
184+
extern "C"
185+
void
186+
MSH3_CALL
187+
MsH3RequestSetReceiveEnabled(
188+
MSH3_REQUEST* Handle,
189+
bool Enabled
190+
)
191+
{
192+
((MsH3BiDirStream*)Handle)->SetReceiveEnabled(Enabled);
193+
}
194+
173195
extern "C"
174196
bool
175197
MSH3_CALL
@@ -801,10 +823,7 @@ MsH3BiDirStream::MsQuicCallback(
801823
}
802824
break;
803825
case QUIC_STREAM_EVENT_RECEIVE:
804-
for (uint32_t i = 0; i < Event->RECEIVE.BufferCount; ++i) {
805-
Receive(Event->RECEIVE.Buffers + i);
806-
}
807-
break;
826+
return Receive(Event);
808827
case QUIC_STREAM_EVENT_SEND_COMPLETE:
809828
if (Event->SEND_COMPLETE.ClientContext) {
810829
auto AppSend = (MsH3AppSend*)Event->SEND_COMPLETE.ClientContext;
@@ -830,71 +849,118 @@ MsH3BiDirStream::MsQuicCallback(
830849
return QUIC_STATUS_SUCCESS;
831850
}
832851

833-
void
852+
QUIC_STATUS
834853
MsH3BiDirStream::Receive(
835-
_In_ const QUIC_BUFFER* Buffer
854+
_Inout_ QUIC_STREAM_EVENT* Event
836855
)
837856
{
838-
uint32_t Offset = 0;
839-
840-
do {
841-
if (CurFrameLengthLeft == 0) {
842-
if (BufferedHeadersLength == 0) {
843-
if (!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &Offset, &CurFrameType) ||
844-
!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &Offset, &CurFrameLength)) {
845-
BufferedHeadersLength = Buffer->Length - Offset;
846-
memcpy(BufferedHeaders, Buffer->Buffer + Offset, BufferedHeadersLength);
847-
return;
848-
}
849-
} else {
850-
uint32_t ToCopy = sizeof(BufferedHeaders) - BufferedHeadersLength;
851-
if (ToCopy > Buffer->Length) ToCopy = Buffer->Length;
852-
memcpy(BufferedHeaders + BufferedHeadersLength, Buffer->Buffer, ToCopy);
853-
if (!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &Offset, &CurFrameType) ||
854-
!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &Offset, &CurFrameLength)) {
855-
BufferedHeadersLength += ToCopy;
856-
return;
857+
for (uint32_t i = 0; i < Event->RECEIVE.BufferCount; ++i) {
858+
const QUIC_BUFFER* Buffer = Event->RECEIVE.Buffers + i;
859+
do {
860+
if (CurFrameLengthLeft == 0) { // Not in the middle of reading frame payload
861+
if (BufferedHeadersLength == 0) { // No partial frame header bufferred
862+
if (!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &CurRecvOffset, &CurFrameType) ||
863+
!MsH3VarIntDecode(Buffer->Length, Buffer->Buffer, &CurRecvOffset, &CurFrameLength)) {
864+
BufferedHeadersLength = Buffer->Length - CurRecvOffset;
865+
memcpy(BufferedHeaders, Buffer->Buffer + CurRecvOffset, BufferedHeadersLength);
866+
break;
867+
}
868+
} else { // Partial frame header bufferred already
869+
uint32_t ToCopy = sizeof(BufferedHeaders) - BufferedHeadersLength;
870+
if (ToCopy > Buffer->Length) ToCopy = Buffer->Length;
871+
memcpy(BufferedHeaders + BufferedHeadersLength, Buffer->Buffer, ToCopy);
872+
if (!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &CurRecvOffset, &CurFrameType) ||
873+
!MsH3VarIntDecode(BufferedHeadersLength+ToCopy, BufferedHeaders, &CurRecvOffset, &CurFrameLength)) {
874+
BufferedHeadersLength += ToCopy;
875+
break;
876+
}
877+
CurRecvOffset -= BufferedHeadersLength;
878+
BufferedHeadersLength = 0;
857879
}
858-
Offset -= BufferedHeadersLength;
859-
BufferedHeadersLength = 0;
880+
CurFrameLengthLeft = CurFrameLength;
860881
}
861-
CurFrameLengthLeft = CurFrameLength;
862-
}
863882

864-
uint32_t AvailFrameLength;
865-
if (Offset + CurFrameLengthLeft > (uint64_t)Buffer->Length) {
866-
AvailFrameLength = Buffer->Length - Offset; // Rest of the buffer
867-
} else {
868-
AvailFrameLength = (uint32_t)CurFrameLengthLeft;
869-
}
883+
uint32_t AvailFrameLength;
884+
if (CurRecvOffset + CurFrameLengthLeft > (uint64_t)Buffer->Length) {
885+
AvailFrameLength = Buffer->Length - CurRecvOffset; // Rest of the buffer
886+
} else {
887+
AvailFrameLength = (uint32_t)CurFrameLengthLeft;
888+
}
870889

871-
if (CurFrameType == H3FrameData) {
872-
Callbacks.DataReceived((MSH3_REQUEST*)this, Context, AvailFrameLength, Buffer->Buffer + Offset);
873-
} else if (CurFrameType == H3FrameHeaders) {
874-
const uint8_t* Frame = Buffer->Buffer + Offset;
875-
if (CurFrameLengthLeft == CurFrameLength) {
876-
auto rhs =
877-
lsqpack_dec_header_in(
878-
&H3.Decoder, this, ID(), (size_t)CurFrameLength, &Frame,
879-
AvailFrameLength, nullptr, nullptr);
880-
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
881-
printf("lsqpack_dec_header_in failure res=%u\n", rhs);
890+
if (CurFrameType == H3FrameData) {
891+
uint32_t AppReceiveLength = AvailFrameLength;
892+
ReceivePending = true;
893+
if (Callbacks.DataReceived((MSH3_REQUEST*)this, Context, &AppReceiveLength, Buffer->Buffer + CurRecvOffset)) {
894+
ReceivePending = false; // Not pending receive
895+
if (AppReceiveLength < AvailFrameLength) { // Partial receive case
896+
CurFrameLengthLeft -= AppReceiveLength;
897+
Event->RECEIVE.TotalBufferLength = CurRecvCompleteLength + CurRecvOffset + AppReceiveLength;
898+
CurRecvCompleteLength = 0;
899+
CurRecvOffset = 0;
900+
return QUIC_STATUS_SUCCESS;
901+
}
902+
} else { // Receive pending (but may have been completed via API call already)
903+
if (!ReceivePending) {
904+
// TODO - Support continuing this receive since it was completed via the API call
905+
}
906+
return QUIC_STATUS_PENDING;
882907
}
883-
} else { // Continued from a previous partial read
884-
auto rhs =
885-
lsqpack_dec_header_read(
886-
&H3.Decoder, this, &Frame, AvailFrameLength, nullptr,
887-
nullptr);
888-
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
889-
printf("lsqpack_dec_header_read failure res=%u\n", rhs);
908+
} else if (CurFrameType == H3FrameHeaders) {
909+
const uint8_t* Frame = Buffer->Buffer + CurRecvOffset;
910+
if (CurFrameLengthLeft == CurFrameLength) {
911+
auto rhs =
912+
lsqpack_dec_header_in(
913+
&H3.Decoder, this, ID(), (size_t)CurFrameLength, &Frame,
914+
AvailFrameLength, nullptr, nullptr);
915+
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
916+
printf("lsqpack_dec_header_in failure res=%u\n", rhs);
917+
}
918+
} else { // Continued from a previous partial read
919+
auto rhs =
920+
lsqpack_dec_header_read(
921+
&H3.Decoder, this, &Frame, AvailFrameLength, nullptr,
922+
nullptr);
923+
if (rhs != LQRHS_DONE && rhs != LQRHS_NEED) {
924+
printf("lsqpack_dec_header_read failure res=%u\n", rhs);
925+
}
890926
}
891927
}
892-
}
893928

894-
CurFrameLengthLeft -= AvailFrameLength;
895-
Offset += AvailFrameLength;
929+
CurFrameLengthLeft -= AvailFrameLength;
930+
CurRecvOffset += AvailFrameLength;
931+
932+
} while (CurRecvOffset < Buffer->Length);
933+
934+
CurRecvCompleteLength += Buffer->Length;
935+
CurRecvOffset = 0;
936+
}
896937

897-
} while (Offset < Buffer->Length);
938+
CurRecvCompleteLength = 0;
939+
940+
return QUIC_STATUS_SUCCESS;
941+
}
942+
943+
void
944+
MsH3BiDirStream::CompleteReceive(
945+
_In_ uint32_t Length
946+
)
947+
{
948+
if (ReceivePending) {
949+
ReceivePending = false;
950+
CurFrameLengthLeft -= Length;
951+
auto CompleteLength = CurRecvCompleteLength + CurRecvOffset + Length;
952+
CurRecvCompleteLength = 0;
953+
CurRecvOffset = 0;
954+
(void)ReceiveComplete(CompleteLength);
955+
}
956+
}
957+
958+
void
959+
MsH3BiDirStream::SetReceiveEnabled(
960+
_In_ bool Enabled
961+
)
962+
{
963+
(void)ReceiveSetEnabled(Enabled);
898964
}
899965

900966
struct lsxpack_header*

‎lib/msh3.hpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -497,12 +497,15 @@ struct MsH3BiDirStream : public MsQuicStream {
497497
QUIC_VAR_INT CurFrameType {0};
498498
QUIC_VAR_INT CurFrameLength {0};
499499
QUIC_VAR_INT CurFrameLengthLeft {0};
500+
uint64_t CurRecvCompleteLength {0};
501+
uint32_t CurRecvOffset {0};
500502

501503
uint8_t BufferedHeaders[2*sizeof(uint64_t)];
502504
uint32_t BufferedHeadersLength {0};
503505

504506
bool Complete {false};
505507
bool ShutdownComplete {false};
508+
bool ReceivePending {false};
506509

507510
MsH3BiDirStream(
508511
_In_ MsH3Connection& Connection,
@@ -521,6 +524,16 @@ struct MsH3BiDirStream : public MsQuicStream {
521524
);
522525
#endif
523526

527+
void
528+
CompleteReceive(
529+
_In_ uint32_t Length
530+
);
531+
532+
void
533+
SetReceiveEnabled(
534+
_In_ bool Enabled
535+
);
536+
524537
bool
525538
SendAppData(
526539
_In_ MSH3_REQUEST_FLAGS Flags,
@@ -551,9 +564,9 @@ struct MsH3BiDirStream : public MsQuicStream {
551564

552565
private:
553566

554-
void
567+
QUIC_STATUS
555568
Receive(
556-
_In_ const QUIC_BUFFER* Buffer
569+
_Inout_ QUIC_STREAM_EVENT* Event
557570
);
558571

559572
static QUIC_STATUS

‎lib/win32/msh3.def

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ EXPORTS
1111
MsH3ConnectionSetCertificate
1212
MsH3RequestOpen
1313
MsH3RequestClose
14+
MsH3RequestCompleteReceive
15+
MsH3RequestSetReceiveEnabled
1416
MsH3RequestSend
1517
MsH3RequestShutdown
1618
MsH3RequestSetCallbackInterface

‎msh3.h

+15-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ MsH3ConnectionSetCertificate(
209209

210210
typedef struct MSH3_REQUEST_IF {
211211
void (MSH3_CALL *HeaderReceived)(MSH3_REQUEST* Request, void* IfContext, const MSH3_HEADER* Header);
212-
void (MSH3_CALL *DataReceived)(MSH3_REQUEST* Request, void* IfContext, uint32_t Length, const uint8_t* Data);
212+
bool (MSH3_CALL *DataReceived)(MSH3_REQUEST* Request, void* IfContext, uint32_t* Length, const uint8_t* Data);
213213
void (MSH3_CALL *Complete)(MSH3_REQUEST* Request, void* IfContext, bool Aborted, uint64_t AbortError);
214214
void (MSH3_CALL *ShutdownComplete)(MSH3_REQUEST* Request, void* IfContext);
215215
void (MSH3_CALL *DataSent)(MSH3_REQUEST* Request, void* IfContext, void* SendContext);
@@ -232,6 +232,20 @@ MsH3RequestClose(
232232
MSH3_REQUEST* Handle
233233
);
234234

235+
void
236+
MSH3_CALL
237+
MsH3RequestCompleteReceive(
238+
MSH3_REQUEST* Handle,
239+
uint32_t Length
240+
);
241+
242+
void
243+
MSH3_CALL
244+
MsH3RequestSetReceiveEnabled(
245+
MSH3_REQUEST* Handle,
246+
bool Enabled
247+
);
248+
235249
bool
236250
MSH3_CALL
237251
MsH3RequestSend(

‎test/msh3test.cpp

+58-1
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,68 @@ DEF_TEST(SimpleRequest) {
8686
return true;
8787
}
8888

89+
bool ReceiveData(bool Async, bool Inline = true) {
90+
struct Context {
91+
bool Async; bool Inline;
92+
TestWaitable<uint32_t> Data;
93+
Context(bool Async, bool Inline) : Async(Async), Inline(Inline) {}
94+
static bool RecvData(TestRequest* Request, uint32_t* Length, const uint8_t* /* Data */) {
95+
auto ctx = (Context*)Request->AppContext;
96+
ctx->Data.Set(*Length);
97+
if (ctx->Async) {
98+
if (ctx->Inline) {
99+
Request->CompleteReceive(*Length);
100+
}
101+
return false;
102+
}
103+
return true;
104+
}
105+
};
106+
TestApi Api; VERIFY(Api.IsValid());
107+
TestCertificate Cert(Api); VERIFY(Cert.IsValid());
108+
TestListener Listener(Api); VERIFY(Listener.IsValid());
109+
TestConnection Connection(Api); VERIFY(Connection.IsValid());
110+
Context Context(Async, Inline);
111+
TestRequest Request(Connection, RequestHeaders, RequestHeadersCount, MSH3_REQUEST_FLAG_FIN, &Context, nullptr, Context::RecvData);
112+
VERIFY(Request.IsValid());
113+
VERIFY(Listener.NewConnection.WaitFor());
114+
auto ServerConnection = Listener.NewConnection.Get();
115+
ServerConnection->SetCertificate(Cert);
116+
VERIFY(ServerConnection->Connected.WaitFor());
117+
VERIFY(Connection.Connected.WaitFor());
118+
VERIFY(ServerConnection->NewRequest.WaitFor());
119+
auto ServerRequest = ServerConnection->NewRequest.Get();
120+
VERIFY(ServerRequest->Send(MSH3_REQUEST_FLAG_FIN, ResponseData, sizeof(ResponseData)));
121+
VERIFY(Context.Data.WaitFor());
122+
VERIFY(Context.Data.Get() == sizeof(ResponseData));
123+
if (Async && !Inline) {
124+
Request.CompleteReceive(Context.Data.Get());
125+
}
126+
VERIFY(Request.Complete.WaitFor());
127+
VERIFY(ServerRequest->Complete.WaitFor());
128+
return true;
129+
}
130+
131+
DEF_TEST(ReceiveDataInline) {
132+
return ReceiveData(false);
133+
}
134+
135+
DEF_TEST(ReceiveDataAsync) {
136+
return ReceiveData(true, false);
137+
}
138+
139+
DEF_TEST(ReceiveDataAsyncInline) {
140+
return ReceiveData(true, true);
141+
}
142+
89143
const TestFunc TestFunctions[] = {
90144
ADD_TEST(Handshake),
91145
ADD_TEST(HandshakeFail),
92146
//ADD_TEST(HandshakeSetCertTimeout),
93-
ADD_TEST(SimpleRequest)
147+
ADD_TEST(SimpleRequest),
148+
ADD_TEST(ReceiveDataInline),
149+
ADD_TEST(ReceiveDataAsync),
150+
ADD_TEST(ReceiveDataAsyncInline),
94151
};
95152
const uint32_t TestCount = sizeof(TestFunctions)/sizeof(TestFunc);
96153

‎test/msh3test.hpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,27 @@ struct TestConnection {
119119
}
120120
};
121121

122+
typedef void TestHeaderRecvCallback(struct TestRequest* Request, const MSH3_HEADER* Header);
123+
typedef bool TestDataRecvCallback(struct TestRequest* Request, uint32_t* Length, const uint8_t* Data);
124+
122125
struct TestRequest {
123126
MSH3_REQUEST* Handle { nullptr };
124127
TestWaitable<bool> Complete;
125128
TestWaitable<bool> ShutdownComplete;
126129
bool Aborted {false};
127130
uint64_t AbortError {0};
131+
void* AppContext {nullptr};
132+
TestHeaderRecvCallback* HeaderRecv {nullptr};
133+
TestDataRecvCallback* DataRecv {nullptr};
128134
TestRequest(
129135
TestConnection& Connection,
130136
const MSH3_HEADER* Headers,
131137
size_t HeadersCount,
132-
MSH3_REQUEST_FLAGS Flags = MSH3_REQUEST_FLAG_NONE
133-
) noexcept : CleanUp(CleanUpManual) {
138+
MSH3_REQUEST_FLAGS Flags = MSH3_REQUEST_FLAG_NONE,
139+
void* AppContext = nullptr,
140+
TestHeaderRecvCallback* HeaderRecv = nullptr,
141+
TestDataRecvCallback* DataRecv = nullptr
142+
) noexcept : AppContext(AppContext), HeaderRecv(HeaderRecv), DataRecv(DataRecv), CleanUp(CleanUpManual) {
134143
Handle = MsH3RequestOpen(Connection, &Interface, this, Headers, HeadersCount, Flags);
135144
}
136145
TestRequest(MSH3_REQUEST* ServerHandle) noexcept : Handle(ServerHandle), CleanUp(CleanUpAutoDelete) {
@@ -141,13 +150,19 @@ struct TestRequest {
141150
TestRequest operator=(TestRequest& Other) = delete;
142151
bool IsValid() const noexcept { return Handle != nullptr; }
143152
operator MSH3_REQUEST* () const noexcept { return Handle; }
153+
void CompleteReceive(uint32_t Length) noexcept {
154+
MsH3RequestCompleteReceive(Handle, Length);
155+
};
156+
void SetReceiveEnabled(bool Enabled) noexcept {
157+
MsH3RequestSetReceiveEnabled(Handle, Enabled);
158+
};
144159
bool Send(
145160
MSH3_REQUEST_FLAGS Flags,
146161
const void* Data,
147162
uint32_t DataLength,
148-
void* AppContext = nullptr
163+
void* SendContext = nullptr
149164
) noexcept {
150-
return MsH3RequestSend(Handle, Flags, Data, DataLength, AppContext);
165+
return MsH3RequestSend(Handle, Flags, Data, DataLength, SendContext);
151166
}
152167
void Shutdown(
153168
MSH3_REQUEST_SHUTDOWN_FLAGS Flags,
@@ -165,10 +180,6 @@ struct TestRequest {
165180
private:
166181
const TestCleanUpMode CleanUp;
167182
const MSH3_REQUEST_IF Interface { s_OnHeaderReceived, s_OnDataReceived, s_OnComplete, s_OnShutdownComplete, s_OnDataSent };
168-
void OnHeaderReceived(const MSH3_HEADER* /*Header*/) noexcept {
169-
}
170-
void OnDataReceived(uint32_t /*Length*/, const uint8_t* /*Data*/) noexcept {
171-
}
172183
void OnComplete(bool _Aborted, uint64_t _AbortError) noexcept {
173184
Aborted = _Aborted;
174185
AbortError = _AbortError;
@@ -184,10 +195,16 @@ struct TestRequest {
184195
}
185196
private: // Static stuff
186197
static void MSH3_CALL s_OnHeaderReceived(MSH3_REQUEST* /*Request*/, void* IfContext, const MSH3_HEADER* Header) noexcept {
187-
((TestRequest*)IfContext)->OnHeaderReceived(Header);
198+
if (((TestRequest*)IfContext)->HeaderRecv) {
199+
((TestRequest*)IfContext)->HeaderRecv((TestRequest*)IfContext, Header);
200+
}
188201
}
189-
static void MSH3_CALL s_OnDataReceived(MSH3_REQUEST* /*Request*/, void* IfContext, uint32_t Length, const uint8_t* Data) noexcept {
190-
((TestRequest*)IfContext)->OnDataReceived(Length, Data);
202+
static bool MSH3_CALL s_OnDataReceived(MSH3_REQUEST* /*Request*/, void* IfContext, uint32_t* Length, const uint8_t* Data) noexcept {
203+
if (((TestRequest*)IfContext)->DataRecv) {
204+
return ((TestRequest*)IfContext)->DataRecv((TestRequest*)IfContext, Length, Data);
205+
} else {
206+
return true;
207+
}
191208
}
192209
static void MSH3_CALL s_OnComplete(MSH3_REQUEST* /*Request*/, void* IfContext, bool Aborted, uint64_t AbortError) noexcept {
193210
((TestRequest*)IfContext)->OnComplete(Aborted, AbortError);

‎tool/msh3_app.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ void MSH3_CALL HeaderReceived(MSH3_REQUEST* , void* , const MSH3_HEADER* Header)
5858
}
5959
}
6060

61-
void MSH3_CALL DataReceived(MSH3_REQUEST* , void* , uint32_t Length, const uint8_t* Data) {
62-
if (Args.Print) fwrite(Data, 1, Length, stdout);
61+
bool MSH3_CALL DataReceived(MSH3_REQUEST* , void* , uint32_t* Length, const uint8_t* Data) {
62+
if (Args.Print) fwrite(Data, 1, *Length, stdout);
63+
return true;
6364
}
6465

6566
void MSH3_CALL Complete(MSH3_REQUEST* , void* Context, bool Aborted, uint64_t AbortError) {

0 commit comments

Comments
 (0)
Please sign in to comment.