diff --git a/peerconnection.go b/peerconnection.go index db4a4a80d8a..40637c38945 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1784,7 +1784,20 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err return err } + // try to read simulcast IDs from the packet we already have var mid, rid, rsid string + if _, err = handleUnknownRTPPacket( + b[:i], uint8(midExtensionID), //nolint:gosec // G115 + uint8(streamIDExtensionID), //nolint:gosec // G115 + uint8(repairStreamIDExtensionID), //nolint:gosec // G115 + &mid, + &rid, + &rsid, + ); err != nil { + return err + } + + // if the first packet didn't contain simuilcast IDs, then probe more packets var paddingOnly bool for readCount := 0; readCount <= simulcastProbeCount; readCount++ { if mid == "" || (rid == "" && rsid == "") { @@ -1798,7 +1811,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err return err } - if _, paddingOnly, err = handleUnknownRTPPacket( + if paddingOnly, err = handleUnknownRTPPacket( b[:i], uint8(midExtensionID), //nolint:gosec // G115 uint8(streamIDExtensionID), //nolint:gosec // G115 uint8(repairStreamIDExtensionID), //nolint:gosec // G115 diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 2219c2900d1..99ae6e33db3 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -1131,6 +1131,110 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) { //nolint:cyclop close(testFinished) }) + // Assert that we can send just one packet with Simulcast IDs (using extensions) and they will be properly received + t.Run("ExtractIDs", func(t *testing.T) { + offerer, answerer, err := newPair() + assert.NoError(t, err) + + rids := []string{"layer_1", "layer_2", "layer_3"} + ridSelected := rids[0] + + onTrackCalled := atomicBool{} + answerer.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + assert.Equal(t, remote.rid, ridSelected) + onTrackCalled.set(true) + }) + + vp8WriterA, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion1", WithRTPStreamID(rids[0]), + ) + assert.NoError(t, err) + + vp8WriterB, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion1", WithRTPStreamID(rids[1]), + ) + assert.NoError(t, err) + + vp8WriterC, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion1", WithRTPStreamID(rids[2]), + ) + assert.NoError(t, err) + + sender, err := offerer.AddTrack(vp8WriterA) + assert.NoError(t, err) + assert.NotNil(t, sender) + + assert.NoError(t, sender.AddEncoding(vp8WriterB)) + assert.NoError(t, sender.AddEncoding(vp8WriterC)) + + assert.NoError(t, signalPair(offerer, answerer)) + + peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, offerer, answerer) + peerConnectionConnected.Wait() + + ticker := time.NewTicker(time.Millisecond * 20) + defer ticker.Stop() + testFinished := make(chan struct{}) + seenOneStream, seenOneStreamCancel := context.WithCancel(context.Background()) + + go func() { + sentOnePacket := false + + senderTrack := vp8WriterA + + for { + select { + case <-testFinished: + return + case <-ticker.C: + answerer.dtlsTransport.lock.Lock() + if len(answerer.dtlsTransport.simulcastStreams) >= 1 { + seenOneStreamCancel() + } + answerer.dtlsTransport.lock.Unlock() + + senderTrack.mu.Lock() + + // We send just one packet with the RID, that's the point of this test + if !sentOnePacket && len(senderTrack.bindings) > 0 { + sentOnePacket = true + + midExtensionID, _, _ := answerer.api.mediaEngine.getHeaderExtensionID( + RTPHeaderExtensionCapability{sdp.SDESMidURI}, + ) + assert.Greater(t, midExtensionID, 0) + + streamIDExtensionID, _, _ := answerer.api.mediaEngine.getHeaderExtensionID( + RTPHeaderExtensionCapability{sdp.SDESRTPStreamIDURI}, + ) + assert.Greater(t, streamIDExtensionID, 0) + + header := &rtp.Header{ + Version: 2, + SSRC: util.RandUint32(), + } + header.Extension = true + header.ExtensionProfile = 0x1000 + assert.NoError(t, header.SetExtension(uint8(midExtensionID), []byte("0"))) + assert.NoError(t, header.SetExtension(uint8(streamIDExtensionID), []byte(ridSelected))) + + _, err = senderTrack.bindings[0].writeStream.WriteRTP(header, []byte{0, 1, 2, 3, 4, 5}) + assert.NoError(t, err) + } + + senderTrack.mu.Unlock() + } + } + }() + + <-seenOneStream.Done() + + assert.Equal(t, true, onTrackCalled.get()) + + closePairNow(t, offerer, answerer) + close(testFinished) + }) + // Assert that NonSimulcast Traffic isn't incorrectly broken by the probe t.Run("Break NonSimulcast", func(t *testing.T) { unhandledSimulcastError := make(chan struct{}) diff --git a/rtptransceiver.go b/rtptransceiver.go index 5ce0f2d4819..f53a16b4fc3 100644 --- a/rtptransceiver.go +++ b/rtptransceiver.go @@ -291,10 +291,10 @@ func handleUnknownRTPPacket( streamIDExtensionID, repairStreamIDExtensionID uint8, mid, rid, rsid *string, -) (payloadType PayloadType, paddingOnly bool, err error) { +) (paddingOnly bool, err error) { rp := &rtp.Packet{} if err = rp.Unmarshal(buf); err != nil { - return 0, false, err + return false, err } if rp.Padding && len(rp.Payload) == 0 { @@ -302,10 +302,9 @@ func handleUnknownRTPPacket( } if !rp.Header.Extension { - return payloadType, paddingOnly, nil + return paddingOnly, nil } - payloadType = PayloadType(rp.PayloadType) if payload := rp.GetExtension(midExtensionID); payload != nil { *mid = string(payload) } @@ -318,5 +317,5 @@ func handleUnknownRTPPacket( *rsid = string(payload) } - return payloadType, paddingOnly, nil + return paddingOnly, nil }