From 53ff17009bb6565da12662b7b4c89bfcaa4d7060 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Sun, 17 Mar 2024 21:29:32 +0000 Subject: [PATCH] Fix log state propagation --- common/ctstretch/ctstretch.go | 18 +++++++++--------- common/framing/framing.go | 10 ++++++---- riverrun.go | 6 +++--- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/common/ctstretch/ctstretch.go b/common/ctstretch/ctstretch.go index 057bab8..c7ac4a1 100644 --- a/common/ctstretch/ctstretch.go +++ b/common/ctstretch/ctstretch.go @@ -170,7 +170,7 @@ func BytesToUInt16(data []byte, startIDx, endIDx uint64) (uint16, error) { return binary.BigEndian.Uint16(data[startIDx:endIDx]), nil } -func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table16, table8 []uint64, stream cipher.Stream, tb int) error { +func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table16, table8 []uint64, stream cipher.Stream, tb int, logger log.Logger) error { if inputBlockBits != 8 && inputBlockBits != 16 { return fmt.Errorf("ctstretch/bit_manip: input bit block size must be 8 or 16") @@ -195,13 +195,13 @@ func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table1 outputBlockBytes := outputBlockBits / 8 if inputBlockBits == 16 && srcNBytes%2 == 1 { - err := ExpandBytes(src[0:srcNBytes-1], dst[0:uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes], inputBlockBits, outputBlockBits, table16, table8, stream, tb) + err := ExpandBytes(src[0:srcNBytes-1], dst[0:uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes], inputBlockBits, outputBlockBits, table16, table8, stream, tb, logger) if err != nil { return err } - return ExpandBytes(src[srcNBytes-1:], dst[uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes:], 8, outputBlockBits/2, table16, table8, stream, tb) + return ExpandBytes(src[srcNBytes-1:], dst[uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes:], 8, outputBlockBits/2, table16, table8, stream, tb, logger) } - log.Debugf("Expanding to %f, tb: %d", uint64(srcNBytes)*outputBlockBytes/inputBlockBytes, tb) + logger.Debugf("Expanding to %f, tb: %d", uint64(srcNBytes)*outputBlockBytes/inputBlockBytes, tb) var table *[]uint64 if inputBlockBits == 8 { @@ -246,10 +246,10 @@ func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table1 return nil } -func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inversion16, inversion8 map[uint64]uint64, stream cipher.Stream, tb int) error { +func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inversion16, inversion8 map[uint64]uint64, stream cipher.Stream, tb int, logger log.Logger) error { // XXX: tb is for tracing purposes. Remove before release. srcNBytes := len(src) // 1: 1074 2: 2 - log.Debugf("srcNBytes: %d, iBB: %d, oBB: %d, tb: %d", srcNBytes, inputBlockBits, outputBlockBits, tb) + logger.Debugf("srcNBytes: %d, iBB: %d, oBB: %d, tb: %d", srcNBytes, inputBlockBits, outputBlockBits, tb) if inputBlockBits%8 != 0 || inputBlockBits > 64 { return fmt.Errorf("ctstretch/bit_manip: input block size must be a multiple of 8 and less than 64") } @@ -271,16 +271,16 @@ func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inve blocks := uint64(srcNBytes) / inputBlockBytes // 1: 134 2: 0 if (uint64(srcNBytes) % inputBlockBytes) != 0 { // 1: True (=2) 2: True (=2) if blocks == 0 { // 1: False // 2: True - return CompressBytes(src, dst, inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb) + return CompressBytes(src, dst, inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb, logger) } endSrc := blocks * inputBlockBytes // 1072 endDst := blocks * outputBlockBytes // 268 - err := CompressBytes(src[0:endSrc], dst[0:endDst], inputBlockBits, outputBlockBits, inversion16, inversion8, stream, tb) + err := CompressBytes(src[0:endSrc], dst[0:endDst], inputBlockBits, outputBlockBits, inversion16, inversion8, stream, tb, logger) if err != nil { return err } - return CompressBytes(src[endSrc:], dst[endDst:], inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb) + return CompressBytes(src[endSrc:], dst[endDst:], inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb, logger) } diff --git a/common/framing/framing.go b/common/framing/framing.go index 795215c..b21df65 100644 --- a/common/framing/framing.go +++ b/common/framing/framing.go @@ -156,6 +156,8 @@ type BaseDecoder struct { ReceiveBuffer *bytes.Buffer ReceiveDecodedBuffer *bytes.Buffer readBuffer []byte + + logger log.Logger } func (decoder *BaseDecoder) InitBuffers() { @@ -263,9 +265,9 @@ func (decoder *BaseDecoder) Decode(data []byte, frames *bytes.Buffer) (int, erro return 0, err } lengthMask := decoder.Drbg.NextBlock() - log.Debugf("length (raw): %d, length (mask): %d", length, lengthMask) + decoder.logger.Debugf("length (raw): %d, length (mask): %d", length, lengthMask) length ^= binary.BigEndian.Uint16(lengthMask) - log.Debugf("First nextLength: %d", length) + decoder.logger.Debugf("First nextLength: %d", length) if MaximumSegmentLength-int(decoder.LengthLength) < int(length) || decoder.MinPayloadLength > int(length) { // Per "Plaintext Recovery Attacks Against SSH" by // Martin R. Albrecht, Kenneth G. Paterson and Gaven J. Watson, @@ -277,11 +279,11 @@ func (decoder *BaseDecoder) Decode(data []byte, frames *bytes.Buffer) (int, erro // by pretending that the length was a random valid range as per // the countermeasure suggested by Denis Bider in section 6 of the // paper. - log.Debugf("Bad length") + decoder.logger.Debugf("Bad length") decoder.NextLengthInvalid = true length = uint16(csrand.IntRange(decoder.MinPayloadLength, MaximumSegmentLength-int(decoder.LengthLength))) } - log.Debugf("Out nextLength: %d", length) + decoder.logger.Debugf("Out nextLength: %d", length) decoder.NextLength = length } diff --git a/riverrun.go b/riverrun.go index c58f515..8a5f49f 100644 --- a/riverrun.go +++ b/riverrun.go @@ -228,7 +228,7 @@ func (encoder *riverrunEncoder) processLength(length uint16) ([]byte, error) { lengthBytes := make([]byte, f.LengthLength) binary.BigEndian.PutUint16(lengthBytes[:], length) lengthBytesEncoded := make([]byte, encoder.LengthLength) - err := ctstretch.ExpandBytes(lengthBytes[:], lengthBytesEncoded, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, rand.Int()) + err := ctstretch.ExpandBytes(lengthBytes[:], lengthBytesEncoded, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, rand.Int(), encoder.logger) return lengthBytesEncoded, err } @@ -237,7 +237,7 @@ func (encoder *riverrunEncoder) encode(frame, payload []byte) (n int, err error) expandedNBytes := int(ctstretch.ExpandedNBytes(uint64(len(payload)), encoder.compressedBlockBits, encoder.expandedBlockBits)) frameLen := encoder.LengthLength + expandedNBytes encoder.logger.Debugf("Encoding frame of length %d, with payload of length %d. TB: %d", frameLen, expandedNBytes, tb) - err = ctstretch.ExpandBytes(payload[:], frame, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, tb) + err = ctstretch.ExpandBytes(payload[:], frame, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, tb, encoder.logger) if err != nil { return 0, err } @@ -341,7 +341,7 @@ func (decoder *riverrunDecoder) decodePayload(frames *bytes.Buffer) ([]byte, err } func (decoder *riverrunDecoder) compressBytes(raw, res []byte) error { - return ctstretch.CompressBytes(raw, res, decoder.expandedBlockBits, decoder.compressedBlockBits, decoder.revTable16, decoder.revTable8, decoder.readStream, rand.Int()) + return ctstretch.CompressBytes(raw, res, decoder.expandedBlockBits, decoder.compressedBlockBits, decoder.revTable16, decoder.revTable8, decoder.readStream, rand.Int(), decoder.logger) } func (rr *Conn) nextLength() int {