From d912d5ac7c2c114a02af10986994da37391b3cdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torbj=C3=B6rn=20Einarson?= Date: Wed, 21 Feb 2024 17:08:34 +0100 Subject: [PATCH] feat: added counter methods to bits.Reader --- CHANGELOG.md | 1 + bits/reader.go | 46 +++++++++++++++++++++++++++++++++------------ bits/reader_test.go | 34 ++++++++++++++++++++------------- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 501ae5d..2728e6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - InitSegment.TweakSingleTrakLive changes an init segment to fit live streaming - Made bits.Mask() function public +- New counter methods added to bits.Reader ### Changed diff --git a/bits/reader.go b/bits/reader.go index 725b338..a9cf4af 100644 --- a/bits/reader.go +++ b/bits/reader.go @@ -10,10 +10,11 @@ import ( // Reader is a bit reader that stops reading at first error and stores it. // First error can be fetched usiin AccError(). type Reader struct { - rd io.Reader - err error - nrBits int // current number of bits - value uint // current accumulated value + rd io.Reader + err error + n int // current number of bits + value uint // current accumulated value + pos int // current position in reader (in bytes) } // AccError - accumulated error is first error that occurred @@ -24,7 +25,8 @@ func (r *Reader) AccError() error { // NewReader return a new Reader that accumulates errors. func NewReader(rd io.Reader) *Reader { return &Reader{ - rd: rd, + rd: rd, + pos: -1, } } @@ -34,7 +36,7 @@ func (r *Reader) Read(n int) uint { return 0 } - for r.nrBits < n { + for r.n < n { r.value <<= 8 var newByte uint8 err := binary.Read(r.rd, binary.BigEndian, &newByte) @@ -42,14 +44,15 @@ func (r *Reader) Read(n int) uint { r.err = err return 0 } + r.pos++ r.value |= uint(newByte) - r.nrBits += 8 + r.n += 8 } - value := r.value >> uint(r.nrBits-n) + value := r.value >> uint(r.n-n) - r.nrBits -= n - r.value &= Mask(r.nrBits) + r.n -= n + r.value &= Mask(r.n) return value } @@ -78,8 +81,8 @@ func (r *Reader) ReadRemainingBytes() []byte { if r.err != nil { return nil } - if r.nrBits != 0 { - r.err = fmt.Errorf("%d bit instead of byte alignment when reading remaining bytes", r.nrBits) + if r.n != 0 { + r.err = fmt.Errorf("%d bit instead of byte alignment when reading remaining bytes", r.n) return nil } rest, err := ioutil.ReadAll(r.rd) @@ -89,3 +92,22 @@ func (r *Reader) ReadRemainingBytes() []byte { } return rest } + +// NrBytesRead returns how many bytes read into parser. +func (r *Reader) NrBytesRead() int { + return r.pos + 1 // Starts at -1 +} + +// NrBitsRead returns total number of bits read into parser. +func (r *Reader) NrBitsRead() int { + nrBits := r.NrBytesRead() * 8 + if r.NrBitsReadInCurrentByte() != 8 { + nrBits += r.NrBitsReadInCurrentByte() - 8 + } + return nrBits +} + +// NrBitsReadInCurrentByte returns number of bits read in current byte. +func (r *Reader) NrBitsReadInCurrentByte() int { + return 8 - r.n +} diff --git a/bits/reader_test.go b/bits/reader_test.go index fb000bb..c5ef320 100644 --- a/bits/reader_test.go +++ b/bits/reader_test.go @@ -15,21 +15,29 @@ func TestAccErrReader(t *testing.T) { reader := bits.NewReader(rd) cases := []struct { - n int - want uint + readNrBits int + want uint + nrBytesRead int + nrBitsRead int }{ - {2, 3}, // 11 - {3, 7}, // 111 - {5, 28}, // 11100 - {3, 1}, // 001 - {3, 7}, // 111 + {2, 3, 1, 2}, // 11 + {3, 7, 1, 5}, // 111 + {5, 28, 2, 10}, // 11100 + {3, 1, 2, 13}, // 001 + {3, 7, 2, 16}, // 111 } for _, tc := range cases { - got := reader.Read(tc.n) + got := reader.Read(tc.readNrBits) if got != tc.want { - t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) + t.Errorf("Read(%d)=%b, want=%b", tc.readNrBits, got, tc.want) + } + if reader.NrBytesRead() != tc.nrBytesRead { + t.Errorf("NrBytesRead()=%d, want=%d", reader.NrBytesRead(), tc.nrBytesRead) + } + if reader.NrBitsRead() != tc.nrBitsRead { + t.Errorf("NrBitsRead()=%d, want=%d", reader.NrBitsRead(), tc.nrBitsRead) } } err := reader.AccError() @@ -109,8 +117,8 @@ func TestAccErrReaderSigned(t *testing.T) { reader := bits.NewReader(rd) cases := []struct { - n int - want int + readNrBits int + want int }{ {2, -1}, // 11 {3, -1}, // 111 @@ -120,10 +128,10 @@ func TestAccErrReaderSigned(t *testing.T) { } for _, tc := range cases { - got := reader.ReadSigned(tc.n) + got := reader.ReadSigned(tc.readNrBits) if got != tc.want { - t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) + t.Errorf("Read(%d)=%b, want=%b", tc.readNrBits, got, tc.want) } } err := reader.AccError()