Skip to content

Commit

Permalink
Fix log state propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaokangwang committed Mar 17, 2024
1 parent 8447e4f commit 53ff170
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
18 changes: 9 additions & 9 deletions common/ctstretch/ctstretch.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func BytesToUInt16(data []byte, startIDx, endIDx uint64) (uint16, error) {
return binary.BigEndian.Uint16(data[startIDx:endIDx]), nil
}

func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table16, table8 []uint64, stream cipher.Stream, tb int) error {
func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table16, table8 []uint64, stream cipher.Stream, tb int, logger log.Logger) error {

if inputBlockBits != 8 && inputBlockBits != 16 {
return fmt.Errorf("ctstretch/bit_manip: input bit block size must be 8 or 16")
Expand All @@ -195,13 +195,13 @@ func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table1
outputBlockBytes := outputBlockBits / 8

if inputBlockBits == 16 && srcNBytes%2 == 1 {
err := ExpandBytes(src[0:srcNBytes-1], dst[0:uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes], inputBlockBits, outputBlockBits, table16, table8, stream, tb)
err := ExpandBytes(src[0:srcNBytes-1], dst[0:uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes], inputBlockBits, outputBlockBits, table16, table8, stream, tb, logger)
if err != nil {
return err
}
return ExpandBytes(src[srcNBytes-1:], dst[uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes:], 8, outputBlockBits/2, table16, table8, stream, tb)
return ExpandBytes(src[srcNBytes-1:], dst[uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes:], 8, outputBlockBits/2, table16, table8, stream, tb, logger)
}
log.Debugf("Expanding to %f, tb: %d", uint64(srcNBytes)*outputBlockBytes/inputBlockBytes, tb)
logger.Debugf("Expanding to %f, tb: %d", uint64(srcNBytes)*outputBlockBytes/inputBlockBytes, tb)

var table *[]uint64
if inputBlockBits == 8 {
Expand Down Expand Up @@ -246,10 +246,10 @@ func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table1
return nil
}

func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inversion16, inversion8 map[uint64]uint64, stream cipher.Stream, tb int) error {
func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inversion16, inversion8 map[uint64]uint64, stream cipher.Stream, tb int, logger log.Logger) error {
// XXX: tb is for tracing purposes. Remove before release.
srcNBytes := len(src) // 1: 1074 2: 2
log.Debugf("srcNBytes: %d, iBB: %d, oBB: %d, tb: %d", srcNBytes, inputBlockBits, outputBlockBits, tb)
logger.Debugf("srcNBytes: %d, iBB: %d, oBB: %d, tb: %d", srcNBytes, inputBlockBits, outputBlockBits, tb)
if inputBlockBits%8 != 0 || inputBlockBits > 64 {
return fmt.Errorf("ctstretch/bit_manip: input block size must be a multiple of 8 and less than 64")
}
Expand All @@ -271,16 +271,16 @@ func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inve
blocks := uint64(srcNBytes) / inputBlockBytes // 1: 134 2: 0
if (uint64(srcNBytes) % inputBlockBytes) != 0 { // 1: True (=2) 2: True (=2)
if blocks == 0 { // 1: False // 2: True
return CompressBytes(src, dst, inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb)
return CompressBytes(src, dst, inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb, logger)
}

endSrc := blocks * inputBlockBytes // 1072
endDst := blocks * outputBlockBytes // 268
err := CompressBytes(src[0:endSrc], dst[0:endDst], inputBlockBits, outputBlockBits, inversion16, inversion8, stream, tb)
err := CompressBytes(src[0:endSrc], dst[0:endDst], inputBlockBits, outputBlockBits, inversion16, inversion8, stream, tb, logger)
if err != nil {
return err
}
return CompressBytes(src[endSrc:], dst[endDst:], inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb)
return CompressBytes(src[endSrc:], dst[endDst:], inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb, logger)

}

Expand Down
10 changes: 6 additions & 4 deletions common/framing/framing.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ type BaseDecoder struct {
ReceiveBuffer *bytes.Buffer
ReceiveDecodedBuffer *bytes.Buffer
readBuffer []byte

logger log.Logger
}

func (decoder *BaseDecoder) InitBuffers() {
Expand Down Expand Up @@ -263,9 +265,9 @@ func (decoder *BaseDecoder) Decode(data []byte, frames *bytes.Buffer) (int, erro
return 0, err
}
lengthMask := decoder.Drbg.NextBlock()
log.Debugf("length (raw): %d, length (mask): %d", length, lengthMask)
decoder.logger.Debugf("length (raw): %d, length (mask): %d", length, lengthMask)
length ^= binary.BigEndian.Uint16(lengthMask)
log.Debugf("First nextLength: %d", length)
decoder.logger.Debugf("First nextLength: %d", length)
if MaximumSegmentLength-int(decoder.LengthLength) < int(length) || decoder.MinPayloadLength > int(length) {
// Per "Plaintext Recovery Attacks Against SSH" by
// Martin R. Albrecht, Kenneth G. Paterson and Gaven J. Watson,
Expand All @@ -277,11 +279,11 @@ func (decoder *BaseDecoder) Decode(data []byte, frames *bytes.Buffer) (int, erro
// by pretending that the length was a random valid range as per
// the countermeasure suggested by Denis Bider in section 6 of the
// paper.
log.Debugf("Bad length")
decoder.logger.Debugf("Bad length")
decoder.NextLengthInvalid = true
length = uint16(csrand.IntRange(decoder.MinPayloadLength, MaximumSegmentLength-int(decoder.LengthLength)))
}
log.Debugf("Out nextLength: %d", length)
decoder.logger.Debugf("Out nextLength: %d", length)
decoder.NextLength = length
}

Expand Down
6 changes: 3 additions & 3 deletions riverrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func (encoder *riverrunEncoder) processLength(length uint16) ([]byte, error) {
lengthBytes := make([]byte, f.LengthLength)
binary.BigEndian.PutUint16(lengthBytes[:], length)
lengthBytesEncoded := make([]byte, encoder.LengthLength)
err := ctstretch.ExpandBytes(lengthBytes[:], lengthBytesEncoded, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, rand.Int())
err := ctstretch.ExpandBytes(lengthBytes[:], lengthBytesEncoded, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, rand.Int(), encoder.logger)
return lengthBytesEncoded, err
}

Expand All @@ -237,7 +237,7 @@ func (encoder *riverrunEncoder) encode(frame, payload []byte) (n int, err error)
expandedNBytes := int(ctstretch.ExpandedNBytes(uint64(len(payload)), encoder.compressedBlockBits, encoder.expandedBlockBits))
frameLen := encoder.LengthLength + expandedNBytes
encoder.logger.Debugf("Encoding frame of length %d, with payload of length %d. TB: %d", frameLen, expandedNBytes, tb)
err = ctstretch.ExpandBytes(payload[:], frame, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, tb)
err = ctstretch.ExpandBytes(payload[:], frame, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, tb, encoder.logger)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -341,7 +341,7 @@ func (decoder *riverrunDecoder) decodePayload(frames *bytes.Buffer) ([]byte, err
}

func (decoder *riverrunDecoder) compressBytes(raw, res []byte) error {
return ctstretch.CompressBytes(raw, res, decoder.expandedBlockBits, decoder.compressedBlockBits, decoder.revTable16, decoder.revTable8, decoder.readStream, rand.Int())
return ctstretch.CompressBytes(raw, res, decoder.expandedBlockBits, decoder.compressedBlockBits, decoder.revTable16, decoder.revTable8, decoder.readStream, rand.Int(), decoder.logger)
}

func (rr *Conn) nextLength() int {
Expand Down

0 comments on commit 53ff170

Please sign in to comment.