From 8f131c2af7ede9d730ab15afc31fced0c3fa9a07 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Sat, 9 Nov 2019 12:28:16 -0500 Subject: [PATCH 1/3] Add an alternative frame reader --- frame-reader-2.go | 46 ++++++++++++++ frame-reader-2_test.go | 139 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 frame-reader-2.go create mode 100644 frame-reader-2_test.go diff --git a/frame-reader-2.go b/frame-reader-2.go new file mode 100644 index 0000000..4085419 --- /dev/null +++ b/frame-reader-2.go @@ -0,0 +1,46 @@ +// Read a generic "framed" packet consisting of a header and a +// This is used for both TLS Records and TLS Handshake Messages +package mint + +type framing2 interface { + parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) +} + +type frameReader2 struct { + details framing2 + remainder []byte +} + +func newFrameReader2(d framing2) *frameReader2 { + return &frameReader2{ + details: d, + remainder: make([]byte, 0), + } +} + +func (f *frameReader2) ready() bool { + headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + return headerReady && len(f.remainder) >= headerLen+bodyLen +} + +func (f *frameReader2) addChunk(in []byte) { + // Append to the buffer + logf(logTypeFrameReader, "Appending %v", len(in)) + f.remainder = append(f.remainder, in...) +} + +func (f *frameReader2) next() ([]byte, []byte, error) { + // Check to see if we have enough data + headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + if !headerReady || len(f.remainder) < headerLen+bodyLen { + logf(logTypeVerbose, "Read would have blocked") + return nil, nil, AlertWouldBlock + } + + // Read a record off the front of the buffer + header, body := make([]byte, headerLen), make([]byte, bodyLen) + copy(header, f.remainder[:headerLen]) + copy(body, f.remainder[headerLen:headerLen+bodyLen]) + f.remainder = f.remainder[headerLen+bodyLen:] + return header, body, nil +} diff --git a/frame-reader-2_test.go b/frame-reader-2_test.go new file mode 100644 index 0000000..e5666b5 --- /dev/null +++ b/frame-reader-2_test.go @@ -0,0 +1,139 @@ +package mint + +import ( + "strings" + "testing" + + "github.com/bifurcation/mint/syntax" +) + +var ( + simpleFullFrame = unhex("00056162636465") + simpleEmptyFrame = unhex("0000") + variableFullFrame = unhex("40ff" + strings.Repeat("A0", 255)) + variableEmptyFrame = unhex("00") +) + +type simpleHeader2 struct{} + +func (h simpleHeader2) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + headerReady = len(buffer) >= 2 + if !headerReady { + return + } + + headerLen = 2 + bodyLen = (int(buffer[0]) << 8) + int(buffer[1]) + return +} + +type variableHeader struct{} + +func (h variableHeader) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + if len(buffer) == 0 { + headerReady = false + return + } + + // XXX: Need a way to return parse errors other than "insufficient data" + length := struct { + Value uint64 `tls:"varint"` + }{} + read, err := syntax.Unmarshal(buffer, &length) + + headerReady = (err == nil) + if !headerReady { + return + } + + headerLen = read + bodyLen = int(length.Value) + return +} + +type frameReaderTester struct { + details framing2 + headerLenFull int + fullFrame []byte + headerLenEmpty int + emptyFrame []byte +} + +func (frt frameReaderTester) checkFrameFull(t *testing.T, hdr, body []byte) { + assertByteEquals(t, hdr, frt.fullFrame[:frt.headerLenFull]) + assertByteEquals(t, body, frt.fullFrame[frt.headerLenFull:]) +} + +func (frt frameReaderTester) checkFrameEmpty(t *testing.T, hdr, body []byte) { + assertByteEquals(t, hdr, frt.emptyFrame[:frt.headerLenEmpty]) + assertByteEquals(t, body, frt.emptyFrame[frt.headerLenEmpty:]) +} + +func (frt frameReaderTester) TestFrames(t *testing.T) { + r := newFrameReader2(frt.details) + r.addChunk(frt.fullFrame) + hdr, body, err := r.next() + assertNotError(t, err, "Couldn't read frame 1") + frt.checkFrameFull(t, hdr, body) + + r.addChunk(frt.emptyFrame) + hdr, body, err = r.next() + assertNotError(t, err, "Couldn't read frame 2") + frt.checkFrameEmpty(t, hdr, body) +} + +func (frt frameReaderTester) TestTwoFrames(t *testing.T) { + r := newFrameReader2(frt.details) + r.addChunk(frt.fullFrame) + r.addChunk(frt.fullFrame) + hdr, body, err := r.next() + assertNotError(t, err, "Couldn't read frame 1") + frt.checkFrameFull(t, hdr, body) + + hdr, body, err = r.next() + assertNotError(t, err, "Couldn't read frame 2") + frt.checkFrameFull(t, hdr, body) +} + +func (frt frameReaderTester) TestTrickle(t *testing.T) { + r := newFrameReader2(frt.details) + + var hdr, body []byte + var err error + for i := 0; i <= len(frt.fullFrame); i += 1 { + hdr, body, err = r.next() + if i < len(frt.fullFrame) { + assertEquals(t, err, AlertWouldBlock) + assertEquals(t, 0, len(hdr)) + assertEquals(t, 0, len(body)) + r.addChunk(frt.fullFrame[i : i+1]) + } + } + assertNil(t, err, "Error reading") + frt.checkFrameFull(t, hdr, body) +} + +func (frt frameReaderTester) Run(t *testing.T) { + t.Run("frames", frt.TestFrames) + t.Run("two-frames", frt.TestTwoFrames) + t.Run("trickle", frt.TestTrickle) +} + +func TestFrameReader2(t *testing.T) { + cases := map[string]frameReaderTester{ + "simple": frameReaderTester{ + simpleHeader2{}, + 2, simpleFullFrame, + 2, simpleEmptyFrame, + }, + "variable": frameReaderTester{ + variableHeader{}, + 2, variableFullFrame, + 1, variableEmptyFrame, + }, + } + + for label, c := range cases { + t.Run(label, c.Run) + } +} From 07c556d768f0c76f3ab4256e4767fe35a47d1c4c Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Sat, 9 Nov 2019 13:00:55 -0500 Subject: [PATCH 2/3] Migrate record layer and handshake layer to new frame reader --- frame-reader-2.go | 18 ++++++++++++++++++ frame-reader-2_test.go | 25 ++++++------------------- handshake-layer.go | 12 +++++++----- record-layer.go | 32 +++++++------------------------- 4 files changed, 38 insertions(+), 49 deletions(-) diff --git a/frame-reader-2.go b/frame-reader-2.go index 4085419..ce4f99d 100644 --- a/frame-reader-2.go +++ b/frame-reader-2.go @@ -6,6 +6,23 @@ type framing2 interface { parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) } +type lastNBytesFraming struct { + headerSize int + lengthSize int +} + +func (lnb lastNBytesFraming) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + headerReady = len(buffer) >= lnb.headerSize + if !headerReady { + return + } + + headerLen = lnb.headerSize + val, _ := decodeUint(buffer[lnb.headerSize-lnb.lengthSize:], lnb.lengthSize) + bodyLen = int(val) + return +} + type frameReader2 struct { details framing2 remainder []byte @@ -20,6 +37,7 @@ func newFrameReader2(d framing2) *frameReader2 { func (f *frameReader2) ready() bool { headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + //logf(logTypeFrameReader, "header=%v body=(%v > %v)", headerReady, len(f.remainder), headerLen+bodyLen) return headerReady && len(f.remainder) >= headerLen+bodyLen } diff --git a/frame-reader-2_test.go b/frame-reader-2_test.go index e5666b5..fb68014 100644 --- a/frame-reader-2_test.go +++ b/frame-reader-2_test.go @@ -8,25 +8,12 @@ import ( ) var ( - simpleFullFrame = unhex("00056162636465") - simpleEmptyFrame = unhex("0000") + fixedFullFrame = unhex("ff00056162636465") + fixedEmptyFrame = unhex("ff0000") variableFullFrame = unhex("40ff" + strings.Repeat("A0", 255)) variableEmptyFrame = unhex("00") ) -type simpleHeader2 struct{} - -func (h simpleHeader2) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { - headerReady = len(buffer) >= 2 - if !headerReady { - return - } - - headerLen = 2 - bodyLen = (int(buffer[0]) << 8) + int(buffer[1]) - return -} - type variableHeader struct{} func (h variableHeader) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { @@ -121,10 +108,10 @@ func (frt frameReaderTester) Run(t *testing.T) { func TestFrameReader2(t *testing.T) { cases := map[string]frameReaderTester{ - "simple": frameReaderTester{ - simpleHeader2{}, - 2, simpleFullFrame, - 2, simpleEmptyFrame, + "fixed": frameReaderTester{ + lastNBytesFraming{3, 2}, + 3, fixedFullFrame, + 3, fixedEmptyFrame, }, "variable": frameReaderTester{ variableHeader{}, diff --git a/handshake-layer.go b/handshake-layer.go index ae11cb8..0ba434b 100644 --- a/handshake-layer.go +++ b/handshake-layer.go @@ -121,7 +121,7 @@ type HandshakeLayer struct { ctx *HandshakeContext // The handshake we are attached to nonblocking bool // Should we operate in nonblocking mode conn RecordLayer // Used for reading/writing records - frame *frameReader // The buffered frame reader + frame *frameReader2 // The buffered frame reader datagram bool // Is this DTLS? msgSeq uint32 // The DTLS message sequence number queued []*HandshakeMessage // In/out queue @@ -130,6 +130,7 @@ type HandshakeLayer struct { maxFragmentLen int } +/* type handshakeLayerFrameDetails struct { datagram bool } @@ -152,13 +153,14 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { val, _ := decodeUint(hdr[len(hdr)-3:], 3) return int(val), nil } +*/ func NewHandshakeLayerTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer { h := HandshakeLayer{} h.ctx = c h.conn = r h.datagram = false - h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) + h.frame = newFrameReader2(lastNBytesFraming{handshakeHeaderLenTLS, 3}) h.maxFragmentLen = maxFragmentLen return &h } @@ -168,7 +170,7 @@ func NewHandshakeLayerDTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer { h.ctx = c h.conn = r h.datagram = true - h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) + h.frame = newFrameReader2(lastNBytesFraming{handshakeHeaderLenDTLS, 3}) h.maxFragmentLen = initialMtu // Not quite right return &h } @@ -359,7 +361,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { } for { logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) - if h.frame.needed() > 0 { + if !h.frame.ready() { logf(logTypeVerbose, "Trying to read a new record") err = h.readRecord() @@ -368,7 +370,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { } } - hdr, body, err = h.frame.process() + hdr, body, err = h.frame.next() if err == nil { break } diff --git a/record-layer.go b/record-layer.go index ccb90bd..5b90df5 100644 --- a/record-layer.go +++ b/record-layer.go @@ -11,6 +11,7 @@ const ( sequenceNumberLen = 8 // sequence number length recordHeaderLenTLS = 5 // record header length (TLS) recordHeaderLenDTLS = 13 // record header length (DTLS) + maxHeaderLen = 256 // invented upper bound for header size maxFragmentLen = 1 << 14 // max number of bytes in a record labelForKey = "key" labelForIV = "iv" @@ -88,7 +89,7 @@ type DefaultRecordLayer struct { direction Direction version uint16 // The current version number conn io.ReadWriter // The underlying connection - frame *frameReader // The buffered frame reader + frame *frameReader2 // The buffered frame reader nextData []byte // The next record to send cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedError error // Error on the last record read @@ -103,25 +104,6 @@ func (r *DefaultRecordLayer) Impl() *DefaultRecordLayer { return r } -type recordLayerFrameDetails struct { - datagram bool -} - -func (d recordLayerFrameDetails) headerLen() int { - if d.datagram { - return recordHeaderLenDTLS - } - return recordHeaderLenTLS -} - -func (d recordLayerFrameDetails) defaultReadLen() int { - return d.headerLen() + maxFragmentLen -} - -func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { - return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil -} - func newCipherStateNull() *cipherState { return &cipherState{EpochClear, 0, 0, nil, nil} } @@ -140,7 +122,7 @@ func NewRecordLayerTLS(conn io.ReadWriter, dir Direction) *DefaultRecordLayer { r.label = "" r.direction = dir r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{false}) + r.frame = newFrameReader2(lastNBytesFraming{recordHeaderLenTLS, 2}) r.cipher = newCipherStateNull() r.version = tls10Version return &r @@ -151,7 +133,7 @@ func NewRecordLayerDTLS(conn io.ReadWriter, dir Direction) *DefaultRecordLayer { r.label = "" r.direction = dir r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{true}) + r.frame = newFrameReader2(lastNBytesFraming{recordHeaderLenDTLS, 2}) r.cipher = newCipherStateNull() r.readCiphers = make(map[Epoch]*cipherState, 0) r.readCiphers[0] = r.cipher @@ -352,8 +334,8 @@ func (r *DefaultRecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, erro var header, body []byte for err != nil { - if r.frame.needed() > 0 { - buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) + if !r.frame.ready() { + buf := make([]byte, maxHeaderLen+maxFragmentLen) n, err := r.conn.Read(buf) if err != nil { logf(logTypeIO, "%s Error reading, %v", r.label, err) @@ -370,7 +352,7 @@ func (r *DefaultRecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, erro r.frame.addChunk(buf) } - header, body, err = r.frame.process() + header, body, err = r.frame.next() // Loop around onAlertWouldBlock to see if some // data is now available. if err != nil && err != AlertWouldBlock { From ea6f506c4d9ac99c5512625671b4615c8fe87b31 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Sat, 9 Nov 2019 13:03:08 -0500 Subject: [PATCH 3/3] Remvoe the old implementation and '2' labeling --- frame-reader-2.go | 64 -------------------- frame-reader-2_test.go | 126 ---------------------------------------- frame-reader.go | 108 ++++++++++++---------------------- frame-reader_test.go | 129 ++++++++++++++++++++++++++++------------- handshake-layer.go | 6 +- record-layer.go | 6 +- 6 files changed, 133 insertions(+), 306 deletions(-) delete mode 100644 frame-reader-2.go delete mode 100644 frame-reader-2_test.go diff --git a/frame-reader-2.go b/frame-reader-2.go deleted file mode 100644 index ce4f99d..0000000 --- a/frame-reader-2.go +++ /dev/null @@ -1,64 +0,0 @@ -// Read a generic "framed" packet consisting of a header and a -// This is used for both TLS Records and TLS Handshake Messages -package mint - -type framing2 interface { - parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) -} - -type lastNBytesFraming struct { - headerSize int - lengthSize int -} - -func (lnb lastNBytesFraming) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { - headerReady = len(buffer) >= lnb.headerSize - if !headerReady { - return - } - - headerLen = lnb.headerSize - val, _ := decodeUint(buffer[lnb.headerSize-lnb.lengthSize:], lnb.lengthSize) - bodyLen = int(val) - return -} - -type frameReader2 struct { - details framing2 - remainder []byte -} - -func newFrameReader2(d framing2) *frameReader2 { - return &frameReader2{ - details: d, - remainder: make([]byte, 0), - } -} - -func (f *frameReader2) ready() bool { - headerReady, headerLen, bodyLen := f.details.parse(f.remainder) - //logf(logTypeFrameReader, "header=%v body=(%v > %v)", headerReady, len(f.remainder), headerLen+bodyLen) - return headerReady && len(f.remainder) >= headerLen+bodyLen -} - -func (f *frameReader2) addChunk(in []byte) { - // Append to the buffer - logf(logTypeFrameReader, "Appending %v", len(in)) - f.remainder = append(f.remainder, in...) -} - -func (f *frameReader2) next() ([]byte, []byte, error) { - // Check to see if we have enough data - headerReady, headerLen, bodyLen := f.details.parse(f.remainder) - if !headerReady || len(f.remainder) < headerLen+bodyLen { - logf(logTypeVerbose, "Read would have blocked") - return nil, nil, AlertWouldBlock - } - - // Read a record off the front of the buffer - header, body := make([]byte, headerLen), make([]byte, bodyLen) - copy(header, f.remainder[:headerLen]) - copy(body, f.remainder[headerLen:headerLen+bodyLen]) - f.remainder = f.remainder[headerLen+bodyLen:] - return header, body, nil -} diff --git a/frame-reader-2_test.go b/frame-reader-2_test.go deleted file mode 100644 index fb68014..0000000 --- a/frame-reader-2_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package mint - -import ( - "strings" - "testing" - - "github.com/bifurcation/mint/syntax" -) - -var ( - fixedFullFrame = unhex("ff00056162636465") - fixedEmptyFrame = unhex("ff0000") - variableFullFrame = unhex("40ff" + strings.Repeat("A0", 255)) - variableEmptyFrame = unhex("00") -) - -type variableHeader struct{} - -func (h variableHeader) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { - if len(buffer) == 0 { - headerReady = false - return - } - - // XXX: Need a way to return parse errors other than "insufficient data" - length := struct { - Value uint64 `tls:"varint"` - }{} - read, err := syntax.Unmarshal(buffer, &length) - - headerReady = (err == nil) - if !headerReady { - return - } - - headerLen = read - bodyLen = int(length.Value) - return -} - -type frameReaderTester struct { - details framing2 - headerLenFull int - fullFrame []byte - headerLenEmpty int - emptyFrame []byte -} - -func (frt frameReaderTester) checkFrameFull(t *testing.T, hdr, body []byte) { - assertByteEquals(t, hdr, frt.fullFrame[:frt.headerLenFull]) - assertByteEquals(t, body, frt.fullFrame[frt.headerLenFull:]) -} - -func (frt frameReaderTester) checkFrameEmpty(t *testing.T, hdr, body []byte) { - assertByteEquals(t, hdr, frt.emptyFrame[:frt.headerLenEmpty]) - assertByteEquals(t, body, frt.emptyFrame[frt.headerLenEmpty:]) -} - -func (frt frameReaderTester) TestFrames(t *testing.T) { - r := newFrameReader2(frt.details) - r.addChunk(frt.fullFrame) - hdr, body, err := r.next() - assertNotError(t, err, "Couldn't read frame 1") - frt.checkFrameFull(t, hdr, body) - - r.addChunk(frt.emptyFrame) - hdr, body, err = r.next() - assertNotError(t, err, "Couldn't read frame 2") - frt.checkFrameEmpty(t, hdr, body) -} - -func (frt frameReaderTester) TestTwoFrames(t *testing.T) { - r := newFrameReader2(frt.details) - r.addChunk(frt.fullFrame) - r.addChunk(frt.fullFrame) - hdr, body, err := r.next() - assertNotError(t, err, "Couldn't read frame 1") - frt.checkFrameFull(t, hdr, body) - - hdr, body, err = r.next() - assertNotError(t, err, "Couldn't read frame 2") - frt.checkFrameFull(t, hdr, body) -} - -func (frt frameReaderTester) TestTrickle(t *testing.T) { - r := newFrameReader2(frt.details) - - var hdr, body []byte - var err error - for i := 0; i <= len(frt.fullFrame); i += 1 { - hdr, body, err = r.next() - if i < len(frt.fullFrame) { - assertEquals(t, err, AlertWouldBlock) - assertEquals(t, 0, len(hdr)) - assertEquals(t, 0, len(body)) - r.addChunk(frt.fullFrame[i : i+1]) - } - } - assertNil(t, err, "Error reading") - frt.checkFrameFull(t, hdr, body) -} - -func (frt frameReaderTester) Run(t *testing.T) { - t.Run("frames", frt.TestFrames) - t.Run("two-frames", frt.TestTwoFrames) - t.Run("trickle", frt.TestTrickle) -} - -func TestFrameReader2(t *testing.T) { - cases := map[string]frameReaderTester{ - "fixed": frameReaderTester{ - lastNBytesFraming{3, 2}, - 3, fixedFullFrame, - 3, fixedEmptyFrame, - }, - "variable": frameReaderTester{ - variableHeader{}, - 2, variableFullFrame, - 1, variableEmptyFrame, - }, - } - - for label, c := range cases { - t.Run(label, c.Run) - } -} diff --git a/frame-reader.go b/frame-reader.go index 4ccfc23..22eac2b 100644 --- a/frame-reader.go +++ b/frame-reader.go @@ -3,96 +3,62 @@ package mint type framing interface { - headerLen() int - defaultReadLen() int - frameLen(hdr []byte) (int, error) + parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) } -const ( - kFrameReaderHdr = 0 - kFrameReaderBody = 1 -) +type lastNBytesFraming struct { + headerSize int + lengthSize int +} -type frameNextAction func(f *frameReader) error +func (lnb lastNBytesFraming) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + headerReady = len(buffer) >= lnb.headerSize + if !headerReady { + return + } + + headerLen = lnb.headerSize + val, _ := decodeUint(buffer[lnb.headerSize-lnb.lengthSize:], lnb.lengthSize) + bodyLen = int(val) + return +} type frameReader struct { - details framing - state uint8 - header []byte - body []byte - working []byte - writeOffset int - remainder []byte + details framing + remainder []byte } func newFrameReader(d framing) *frameReader { - hdr := make([]byte, d.headerLen()) return &frameReader{ - d, - kFrameReaderHdr, - hdr, - nil, - hdr, - 0, - nil, + details: d, + remainder: make([]byte, 0), } } -func dup(a []byte) []byte { - r := make([]byte, len(a)) - copy(r, a) - return r -} - -func (f *frameReader) needed() int { - tmp := (len(f.working) - f.writeOffset) - len(f.remainder) - if tmp < 0 { - return 0 - } - return tmp +func (f *frameReader) ready() bool { + headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + //logf(logTypeFrameReader, "header=%v body=(%v > %v)", headerReady, len(f.remainder), headerLen+bodyLen) + return headerReady && len(f.remainder) >= headerLen+bodyLen } func (f *frameReader) addChunk(in []byte) { - // Append to the buffer. + // Append to the buffer logf(logTypeFrameReader, "Appending %v", len(in)) f.remainder = append(f.remainder, in...) } -func (f *frameReader) process() (hdr []byte, body []byte, err error) { - for f.needed() == 0 { - logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) - // Fill out our working block - copied := copy(f.working[f.writeOffset:], f.remainder) - f.remainder = f.remainder[copied:] - f.writeOffset += copied - if f.writeOffset < len(f.working) { - logf(logTypeVerbose, "Read would have blocked 1") - return nil, nil, AlertWouldBlock - } - // Reset the write offset, because we are now full. - f.writeOffset = 0 - - // We have read a full frame - if f.state == kFrameReaderBody { - logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) - f.state = kFrameReaderHdr - f.working = f.header - return dup(f.header), dup(f.body), nil - } - - // We have read the header - bodyLen, err := f.details.frameLen(f.header) - if err != nil { - return nil, nil, err - } - logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) - - f.body = make([]byte, bodyLen) - f.working = f.body - f.writeOffset = 0 - f.state = kFrameReaderBody +func (f *frameReader) next() ([]byte, []byte, error) { + // Check to see if we have enough data + headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + if !headerReady || len(f.remainder) < headerLen+bodyLen { + logf(logTypeVerbose, "Read would have blocked") + return nil, nil, AlertWouldBlock } - logf(logTypeVerbose, "Read would have blocked 2") - return nil, nil, AlertWouldBlock + // Read a record off the front of the buffer + header, body := make([]byte, headerLen), make([]byte, bodyLen) + copy(header, f.remainder[:headerLen]) + copy(body, f.remainder[headerLen:headerLen+bodyLen]) + f.remainder = f.remainder[headerLen+bodyLen:] + return header, body, nil } diff --git a/frame-reader_test.go b/frame-reader_test.go index 4ea5efd..3d29508 100644 --- a/frame-reader_test.go +++ b/frame-reader_test.go @@ -1,75 +1,126 @@ package mint import ( + "strings" "testing" + + "github.com/bifurcation/mint/syntax" ) -var kTestFrame = []byte{0x00, 0x05, 'a', 'b', 'c', 'd', 'e'} -var kTestEmptyFrame = []byte{0x00, 0x00} +var ( + fixedFullFrame = unhex("ff00056162636465") + fixedEmptyFrame = unhex("ff0000") + variableFullFrame = unhex("40ff" + strings.Repeat("A0", 255)) + variableEmptyFrame = unhex("00") +) -type simpleHeader struct{} +type variableHeader struct{} -func (h simpleHeader) headerLen() int { - return 2 -} +func (h variableHeader) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + if len(buffer) == 0 { + headerReady = false + return + } -func (h simpleHeader) defaultReadLen() int { - return 1024 -} + // XXX: Need a way to return parse errors other than "insufficient data" + length := struct { + Value uint64 `tls:"varint"` + }{} + read, err := syntax.Unmarshal(buffer, &length) -func (h simpleHeader) frameLen(hdr []byte) (int, error) { - if len(hdr) != 2 { - panic("Assert!") + headerReady = (err == nil) + if !headerReady { + return } - return (int(hdr[0]) << 8) | int(hdr[1]), nil + headerLen = read + bodyLen = int(length.Value) + return +} + +type frameReaderTester struct { + details framing + headerLenFull int + fullFrame []byte + headerLenEmpty int + emptyFrame []byte +} + +func (frt frameReaderTester) checkFrameFull(t *testing.T, hdr, body []byte) { + assertByteEquals(t, hdr, frt.fullFrame[:frt.headerLenFull]) + assertByteEquals(t, body, frt.fullFrame[frt.headerLenFull:]) } -func checkFrame(t *testing.T, hdr []byte, body []byte) { - assertByteEquals(t, hdr, kTestFrame[:2]) - assertByteEquals(t, body, kTestFrame[2:]) +func (frt frameReaderTester) checkFrameEmpty(t *testing.T, hdr, body []byte) { + assertByteEquals(t, hdr, frt.emptyFrame[:frt.headerLenEmpty]) + assertByteEquals(t, body, frt.emptyFrame[frt.headerLenEmpty:]) } -func TestFrameReaderFullFrame(t *testing.T) { - r := newFrameReader(simpleHeader{}) - r.addChunk(kTestFrame) - hdr, body, err := r.process() +func (frt frameReaderTester) TestFrames(t *testing.T) { + r := newFrameReader(frt.details) + r.addChunk(frt.fullFrame) + hdr, body, err := r.next() assertNotError(t, err, "Couldn't read frame 1") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) - r.addChunk(kTestFrame) - hdr, body, err = r.process() + r.addChunk(frt.emptyFrame) + hdr, body, err = r.next() assertNotError(t, err, "Couldn't read frame 2") - checkFrame(t, hdr, body) + frt.checkFrameEmpty(t, hdr, body) } -func TestFrameReaderTwoFrames(t *testing.T) { - r := newFrameReader(simpleHeader{}) - r.addChunk(kTestFrame) - r.addChunk(kTestFrame) - hdr, body, err := r.process() +func (frt frameReaderTester) TestTwoFrames(t *testing.T) { + r := newFrameReader(frt.details) + r.addChunk(frt.fullFrame) + r.addChunk(frt.fullFrame) + hdr, body, err := r.next() assertNotError(t, err, "Couldn't read frame 1") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) - hdr, body, err = r.process() + hdr, body, err = r.next() assertNotError(t, err, "Couldn't read frame 2") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) } -func TestFrameReaderTrickle(t *testing.T) { - r := newFrameReader(simpleHeader{}) +func (frt frameReaderTester) TestTrickle(t *testing.T) { + r := newFrameReader(frt.details) var hdr, body []byte var err error - for i := 0; i <= len(kTestFrame); i += 1 { - hdr, body, err = r.process() - if i < len(kTestFrame) { + for i := 0; i <= len(frt.fullFrame); i += 1 { + hdr, body, err = r.next() + if i < len(frt.fullFrame) { assertEquals(t, err, AlertWouldBlock) assertEquals(t, 0, len(hdr)) assertEquals(t, 0, len(body)) - r.addChunk(kTestFrame[i : i+1]) + r.addChunk(frt.fullFrame[i : i+1]) } } assertNil(t, err, "Error reading") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) +} + +func (frt frameReaderTester) Run(t *testing.T) { + t.Run("frames", frt.TestFrames) + t.Run("two-frames", frt.TestTwoFrames) + t.Run("trickle", frt.TestTrickle) +} + +func TestFrameReader(t *testing.T) { + cases := map[string]frameReaderTester{ + "fixed": frameReaderTester{ + lastNBytesFraming{3, 2}, + 3, fixedFullFrame, + 3, fixedEmptyFrame, + }, + "variable": frameReaderTester{ + variableHeader{}, + 2, variableFullFrame, + 1, variableEmptyFrame, + }, + } + + for label, c := range cases { + t.Run(label, c.Run) + } } diff --git a/handshake-layer.go b/handshake-layer.go index 0ba434b..ae7f506 100644 --- a/handshake-layer.go +++ b/handshake-layer.go @@ -121,7 +121,7 @@ type HandshakeLayer struct { ctx *HandshakeContext // The handshake we are attached to nonblocking bool // Should we operate in nonblocking mode conn RecordLayer // Used for reading/writing records - frame *frameReader2 // The buffered frame reader + frame *frameReader // The buffered frame reader datagram bool // Is this DTLS? msgSeq uint32 // The DTLS message sequence number queued []*HandshakeMessage // In/out queue @@ -160,7 +160,7 @@ func NewHandshakeLayerTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer { h.ctx = c h.conn = r h.datagram = false - h.frame = newFrameReader2(lastNBytesFraming{handshakeHeaderLenTLS, 3}) + h.frame = newFrameReader(lastNBytesFraming{handshakeHeaderLenTLS, 3}) h.maxFragmentLen = maxFragmentLen return &h } @@ -170,7 +170,7 @@ func NewHandshakeLayerDTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer { h.ctx = c h.conn = r h.datagram = true - h.frame = newFrameReader2(lastNBytesFraming{handshakeHeaderLenDTLS, 3}) + h.frame = newFrameReader(lastNBytesFraming{handshakeHeaderLenDTLS, 3}) h.maxFragmentLen = initialMtu // Not quite right return &h } diff --git a/record-layer.go b/record-layer.go index 5b90df5..c8591be 100644 --- a/record-layer.go +++ b/record-layer.go @@ -89,7 +89,7 @@ type DefaultRecordLayer struct { direction Direction version uint16 // The current version number conn io.ReadWriter // The underlying connection - frame *frameReader2 // The buffered frame reader + frame *frameReader // The buffered frame reader nextData []byte // The next record to send cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedError error // Error on the last record read @@ -122,7 +122,7 @@ func NewRecordLayerTLS(conn io.ReadWriter, dir Direction) *DefaultRecordLayer { r.label = "" r.direction = dir r.conn = conn - r.frame = newFrameReader2(lastNBytesFraming{recordHeaderLenTLS, 2}) + r.frame = newFrameReader(lastNBytesFraming{recordHeaderLenTLS, 2}) r.cipher = newCipherStateNull() r.version = tls10Version return &r @@ -133,7 +133,7 @@ func NewRecordLayerDTLS(conn io.ReadWriter, dir Direction) *DefaultRecordLayer { r.label = "" r.direction = dir r.conn = conn - r.frame = newFrameReader2(lastNBytesFraming{recordHeaderLenDTLS, 2}) + r.frame = newFrameReader(lastNBytesFraming{recordHeaderLenDTLS, 2}) r.cipher = newCipherStateNull() r.readCiphers = make(map[Epoch]*cipherState, 0) r.readCiphers[0] = r.cipher