From 502aabac0f6f89b2e1973585e463d0fd910a8982 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Fri, 15 Mar 2024 21:36:38 +0000 Subject: [PATCH] Add riverrun code base from https://github.com/RACECAR-GU/obfsX --- common/csrand/csrand.go | 101 ++++++++ common/ctstretch/ctstretch.go | 328 +++++++++++++++++++++++ common/ctstretch/ctstretch_test.go | 76 ++++++ common/drbg/hash_drbg.go | 148 +++++++++++ common/framing/framing.go | 322 +++++++++++++++++++++++ common/log/log.go | 35 +++ go.mod | 8 + go.sum | 51 ++++ riverrun.go | 400 +++++++++++++++++++++++++++++ 9 files changed, 1469 insertions(+) create mode 100644 common/csrand/csrand.go create mode 100644 common/ctstretch/ctstretch.go create mode 100644 common/ctstretch/ctstretch_test.go create mode 100644 common/drbg/hash_drbg.go create mode 100644 common/framing/framing.go create mode 100644 common/log/log.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 riverrun.go diff --git a/common/csrand/csrand.go b/common/csrand/csrand.go new file mode 100644 index 0000000..157540b --- /dev/null +++ b/common/csrand/csrand.go @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2014, Yawning Angel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +// Package csrand implements the math/rand interface over crypto/rand, along +// with some utility functions for common random number/byte related tasks. +// +// Not all of the convinience routines are replicated, only those that are +// immediately useful. The Rand variable provides access to the full math/rand +// API. +package csrand // import "github.com/RACECAR-GU/obfsX/common/csrand" + +import ( + cryptRand "crypto/rand" + "encoding/binary" + "fmt" + "io" + "math/rand" +) + +var ( + csRandSourceInstance csRandSource + + // Rand is a math/rand instance backed by crypto/rand CSPRNG. + Rand = rand.New(csRandSourceInstance) +) + +type csRandSource struct { + // This does not keep any state as it is backed by crypto/rand. +} + +func (r csRandSource) Int63() int64 { + var src [8]byte + if err := Bytes(src[:]); err != nil { + panic(err) + } + val := binary.BigEndian.Uint64(src[:]) + val &= (1<<63 - 1) + + return int64(val) +} + +func (r csRandSource) Seed(seed int64) { + // No-op. +} + +// Intn returns, as a int, a pseudo random number in [0, n). +func Intn(n int) int { + return Rand.Intn(n) +} + +// Float64 returns, as a float64, a pesudo random number in [0.0,1.0). +func Float64() float64 { + return Rand.Float64() +} + +// IntRange returns a uniformly distributed int [min, max]. +func IntRange(min, max int) int { + if max < min { + panic(fmt.Sprintf("IntRange: min > max (%d, %d)", min, max)) + } + + r := (max + 1) - min + ret := Rand.Intn(r) + return ret + min +} + +// Bytes fills the slice with random data. +func Bytes(buf []byte) error { + if _, err := io.ReadFull(cryptRand.Reader, buf); err != nil { + return err + } + + return nil +} + +// Reader is a alias of rand.Reader. +var Reader = cryptRand.Reader diff --git a/common/ctstretch/ctstretch.go b/common/ctstretch/ctstretch.go new file mode 100644 index 0000000..057bab8 --- /dev/null +++ b/common/ctstretch/ctstretch.go @@ -0,0 +1,328 @@ +package ctstretch + +import ( + "crypto/cipher" + "encoding/binary" + "fmt" + "math" + "unsafe" + + "github.com/v2fly/riverrun/common/log" +) + +// Swaps bits i and j in data. Bit 0 is the first bit of data[0]. +func BitSwap(data []byte, i, j uint64) error { + + if i == j { + return nil + } + + numBits := uint64(len(data) * 8) + if i >= numBits || j >= numBits { + return fmt.Errorf("ctstretch/bit_manip: index out of bounds") + } + + var iByte *byte = &data[i/8] + var jByte *byte = &data[j/8] + var iBitIdx uint64 = i % 8 + var jBitIdx uint64 = j % 8 + + // If we are swapping bits a and b, the least-sig bit of c now contains + // a XOR b + var c byte = ((*iByte >> iBitIdx) & byte(1)) ^ ((*jByte >> jBitIdx) & byte(1)) + + *iByte = *iByte ^ (c << iBitIdx) + *jByte = *jByte ^ (c << jBitIdx) + return nil +} + +func UniformSample(a, b uint64, stream cipher.Stream) (uint64, error) { + var rnge uint64 + if a >= b { + return rnge, fmt.Errorf("ctstretch/bit_manip: invalid range") + } + + rnge = (b - a + 1) + + var z uint64 = 0 + var r uint64 = 0 + zBytes := (*[unsafe.Sizeof(z)]byte)(unsafe.Pointer(&z))[:] + rBytes := (*[unsafe.Sizeof(r)]byte)(unsafe.Pointer(&r))[:] + + stream.XORKeyStream(rBytes, zBytes) + + for cont := true; cont; cont = (r >= (math.MaxUint64 - (math.MaxUint64 % rnge))) { + stream.XORKeyStream(rBytes, zBytes) + } + + return a + (r % rnge), nil +} + +func BitShuffle(data []byte, rng cipher.Stream, rev bool) error { + numBits := uint64(len(data) * 8) + + shuffleIndices := make([]uint64, numBits-1) + var err error + for idx := uint64(0); idx < (numBits - 1); idx = idx + 1 { + shuffleIndices[idx], err = UniformSample(idx, numBits-1, rng) + if err != nil { + return err + } + } + + for idx := uint64(0); idx < (numBits - 1); idx = idx + 1 { + + kdx := uint64(0) + + if rev { + kdx = (numBits - 2) - idx + } else { + kdx = idx + } + + jdx := shuffleIndices[kdx] + err = BitSwap(data, kdx, jdx) + if err != nil { + return err + } + } + return nil +} + +func PrintBits(data []byte) { + for _, v := range data { + fmt.Printf("%08b\n", v) + } +} + +// Bias of 0.8 means 80% probability of outputting 0 +func SampleBiasedString(numBits uint64, bias float64, stream cipher.Stream) (uint64, error) { + var r uint64 + if numBits > 64 { + return r, fmt.Errorf("ctstretch/bit_manip: numBits out of range") + } + + r = uint64(0) + + for idx := uint64(0); idx < numBits; idx++ { + // Simulate a biased coin flip + sample, err := UniformSample(0, math.MaxUint64-1, stream) + if err != nil { + return r, err + } + x := float64(sample) / float64(math.MaxUint64-1) + b := uint64(0) + if x >= bias { + b++ + } + + r ^= (b << idx) + } + + return r, nil +} + +func SampleBiasedStrings(numBits, n uint64, bias float64, stream cipher.Stream) ([]uint64, error) { + vals := make([]uint64, n) + m := make(map[uint64]bool) + var err error + for idx := uint64(0); idx < n; idx += 1 { + + s := uint64(0) + haveKey := true + + for haveKey == true { + s, err = SampleBiasedString(numBits, bias, stream) + if err != nil { + return nil, err + } + _, haveKey = m[s] + } + + vals[idx] = s + m[s] = true + } + + return vals, nil +} + +func InvertTable(vals []uint64) map[uint64]uint64 { + m := make(map[uint64]uint64) + + for idx, val := range vals { + m[val] = uint64(idx) + } + + return m +} + +func BytesToUInt16(data []byte, startIDx, endIDx uint64) (uint16, error) { + if endIDx <= startIDx || (endIDx-startIDx) > 3 { + var errVal uint16 + return errVal, fmt.Errorf("ctstretch/bit_manip: invalid range") + } + + r := (endIDx - startIDx) + + if r == 1 { + return uint16(data[startIDx]), nil + } + return binary.BigEndian.Uint16(data[startIDx:endIDx]), nil +} + +func ExpandBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, table16, table8 []uint64, stream cipher.Stream, tb int) error { + + if inputBlockBits != 8 && inputBlockBits != 16 { + return fmt.Errorf("ctstretch/bit_manip: input bit block size must be 8 or 16") + } + if outputBlockBits%8 != 0 || outputBlockBits > 64 || outputBlockBits == 0 { + return fmt.Errorf("ctstretch/bit_manip: output block size must be a multiple of 8, less than or equal to 64, and greater than 0") + } + + srcNBytes := len(src) + + if srcNBytes == 0 { + return nil + } + + expansionFactor := float64(outputBlockBits) / float64(inputBlockBits) + + if float64(len(dst))/float64(srcNBytes) < expansionFactor { + return fmt.Errorf("ctstretch/bit_manip: dst has insufficient size") + } + + inputBlockBytes := inputBlockBits / 8 + outputBlockBytes := outputBlockBits / 8 + + if inputBlockBits == 16 && srcNBytes%2 == 1 { + err := ExpandBytes(src[0:srcNBytes-1], dst[0:uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes], inputBlockBits, outputBlockBits, table16, table8, stream, tb) + if err != nil { + return err + } + return ExpandBytes(src[srcNBytes-1:], dst[uint64(srcNBytes-1)*outputBlockBytes/inputBlockBytes:], 8, outputBlockBits/2, table16, table8, stream, tb) + } + log.Debugf("Expanding to %f, tb: %d", uint64(srcNBytes)*outputBlockBytes/inputBlockBytes, tb) + + var table *[]uint64 + if inputBlockBits == 8 { + table = &table8 + } else { + table = &table16 + } + + inputIdx := uint64(0) + outputIdx := uint64(0) + + for ; inputIdx < uint64(srcNBytes); inputIdx = inputIdx + inputBlockBytes { + x, err := BytesToUInt16(src, inputIdx, inputIdx+inputBlockBytes) + if err != nil { + return err + } + tableVal := (*table)[x] + // yuck :( no variable length casts in go. + switch outputBlockBytes { + case 2: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[2]byte)(unsafe.Pointer(&tableVal))[:]) + case 3: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[3]byte)(unsafe.Pointer(&tableVal))[:]) + case 4: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[4]byte)(unsafe.Pointer(&tableVal))[:]) + case 5: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[5]byte)(unsafe.Pointer(&tableVal))[:]) + case 6: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[6]byte)(unsafe.Pointer(&tableVal))[:]) + case 7: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[7]byte)(unsafe.Pointer(&tableVal))[:]) + case 8: + copy(dst[outputIdx:outputIdx+outputBlockBytes], (*[8]byte)(unsafe.Pointer(&tableVal))[:]) + } + + err = BitShuffle(dst[outputIdx:outputIdx+outputBlockBytes], stream, false) + if err != nil { + return err + } + outputIdx += outputBlockBytes + } + return nil +} + +func CompressBytes(src, dst []byte, inputBlockBits, outputBlockBits uint64, inversion16, inversion8 map[uint64]uint64, stream cipher.Stream, tb int) error { + // XXX: tb is for tracing purposes. Remove before release. + srcNBytes := len(src) // 1: 1074 2: 2 + log.Debugf("srcNBytes: %d, iBB: %d, oBB: %d, tb: %d", srcNBytes, inputBlockBits, outputBlockBits, tb) + if inputBlockBits%8 != 0 || inputBlockBits > 64 { + return fmt.Errorf("ctstretch/bit_manip: input block size must be a multiple of 8 and less than 64") + } + if outputBlockBits != 8 && outputBlockBits != 16 { + return fmt.Errorf("ctstretch/bit_manip: output bit block size must be 8 or 16, currently is %d, with input block size at %d, and len(src) %d. Traceback id: %d", outputBlockBits, inputBlockBits, srcNBytes, tb) + } + + // 4 output bits, 16 input bits, 3 total bytes + // Previous call had 8 output bits, 32 input bits, 108 + 3 bytes + // 1 output byte 4 input bytes + + if float64(len(dst))/float64(srcNBytes) < float64(outputBlockBits)/float64(inputBlockBits) { + return fmt.Errorf("ctstretch/bit_manip: dst has insufficient size") + } + + inputBlockBytes := inputBlockBits / 8 // 1: 8 2: 4 + outputBlockBytes := outputBlockBits / 8 // 1: 2 2: 1 + + blocks := uint64(srcNBytes) / inputBlockBytes // 1: 134 2: 0 + if (uint64(srcNBytes) % inputBlockBytes) != 0 { // 1: True (=2) 2: True (=2) + if blocks == 0 { // 1: False // 2: True + return CompressBytes(src, dst, inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb) + } + + endSrc := blocks * inputBlockBytes // 1072 + endDst := blocks * outputBlockBytes // 268 + err := CompressBytes(src[0:endSrc], dst[0:endDst], inputBlockBits, outputBlockBits, inversion16, inversion8, stream, tb) + if err != nil { + return err + } + return CompressBytes(src[endSrc:], dst[endDst:], inputBlockBits/2, outputBlockBits/2, inversion16, inversion8, stream, tb) + + } + + inputIdx := uint64(0) + outputIdx := uint64(0) + + var inversion *map[uint64]uint64 + if outputBlockBits == 8 { + inversion = &inversion8 + } else { + inversion = &inversion16 + } + for ; inputIdx < uint64(srcNBytes); inputIdx = inputIdx + inputBlockBytes { + err := BitShuffle(src[inputIdx:inputIdx+inputBlockBytes], stream, true) + if err != nil { + return err + } + + var x, y uint64 + x = 0 + y = 0 + copy((*[unsafe.Sizeof(x)]byte)(unsafe.Pointer(&x))[:], + src[inputIdx:inputIdx+inputBlockBytes]) + y = (*inversion)[x] + if outputBlockBytes == 1 { + z := uint8(y) + dst[outputIdx] = z + } else { + binary.BigEndian.PutUint16(dst[outputIdx:outputIdx+outputBlockBytes], uint16(y)) + } + outputIdx += outputBlockBytes + } + return nil +} + +func ExpandedNBytes(srcLen, inputBlockBits, outputBlockBits uint64) uint64 { + return srcLen * (outputBlockBits / inputBlockBits) +} + +func CompressedNBytes(expandedLen, inputBlockBits, outputBlockBits uint64) uint64 { + return uint64(math.Ceil(float64(expandedLen) * (float64(outputBlockBits) / float64(inputBlockBits)))) +} +func CompressedNBytes_floor(expandedLen, inputBlockBits, outputBlockBits uint64) uint64 { + return uint64(math.Floor(float64(expandedLen) * (float64(outputBlockBits) / float64(inputBlockBits)))) +} diff --git a/common/ctstretch/ctstretch_test.go b/common/ctstretch/ctstretch_test.go new file mode 100644 index 0000000..73d7177 --- /dev/null +++ b/common/ctstretch/ctstretch_test.go @@ -0,0 +1,76 @@ +package ctstretch + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" +) + +func main() { + key := make([]byte, 16) + rand.Read(key) + + block, err := aes.NewCipher(key) + if err != nil { + panic(err.Error()) + } + + iv := make([]byte, block.BlockSize()) + rand.Read(iv) + + streamClient := cipher.NewCTR(block, iv) + streamServer := cipher.NewCTR(block, iv) + + bias := float64(0.55) + msgLens := []uint64{1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17} + + outputNBits8 := []uint64{16, 24, 32, 40, 48, 56, 64} + outputNBits16 := []uint64{32, 48, 64} + + for _, msgLen := range msgLens { + for _, outputNBits := range outputNBits8 { + runTest(msgLen, 8, outputNBits, bias, streamClient, streamServer) + } + + for _, outputNBits := range outputNBits16 { + runTest(msgLen, 16, outputNBits, bias, streamClient, streamServer) + } + } +} + +func runTest(msgNBytes, inputBlockBits, outputBlockBits uint64, bias float64, + streamClient, streamServer cipher.Stream) { + + clientTable16 := ctstretch.SampleBiasedStrings(outputBlockBits, 65536, bias, streamClient) + serverTable16 := ctstretch.InvertTable(ctstretch.SampleBiasedStrings(outputBlockBits, 65536, bias, streamServer)) + + var outputBlockBits8 uint64 + if inputBlockBits == 8 { + outputBlockBits8 = outputBlockBits + } else { + outputBlockBits8 = outputBlockBits / 2 + } + + clientTable8 := ctstretch.SampleBiasedStrings(outputBlockBits8, 256, bias, streamClient) + serverTable8 := ctstretch.InvertTable(ctstretch.SampleBiasedStrings(outputBlockBits8, 256, bias, streamServer)) + + msg := make([]byte, msgNBytes) + expandedNBytes := ctstretch.ExpandedNBytes(msgNBytes, inputBlockBits, outputBlockBits) + compressedNBytes := ctstretch.CompressedNBytes(expandedNBytes, outputBlockBits, inputBlockBits) + + expanded := make([]byte, expandedNBytes) + rand.Read(msg) + compressed := make([]byte, compressedNBytes) + + ctstretch.ExpandBytes(msg[:], expanded, inputBlockBits, outputBlockBits, clientTable16, clientTable8, streamClient) + ctstretch.CompressBytes(expanded, compressed, outputBlockBits, inputBlockBits, serverTable16, serverTable8, streamServer) + + if bytes.Equal(msg, compressed) { + fmt.Println("Pass:", msgNBytes, inputBlockBits, outputBlockBits) + } else { + fmt.Println(msg, compressed) + fmt.Println("Fail:", msgNBytes, inputBlockBits, outputBlockBits) + } +} diff --git a/common/drbg/hash_drbg.go b/common/drbg/hash_drbg.go new file mode 100644 index 0000000..61ca3cb --- /dev/null +++ b/common/drbg/hash_drbg.go @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2014, Yawning Angel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +// Package drbg implements a minimalistic DRBG based off SipHash-2-4 in OFB +// mode. +package drbg // import "github.com/RACECAR-GU/obfsX/common/drbg" + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "hash" + + "github.com/dchest/siphash" + "github.com/v2fly/riverrun/common/csrand" +) + +// Size is the length of the HashDrbg output. +const Size = siphash.Size + +// SeedLength is the length of the HashDrbg seed. +const SeedLength = 16 + Size + +// Seed is the initial state for a HashDrbg. It consists of a SipHash-2-4 +// key, and 8 bytes of initial data. +type Seed [SeedLength]byte + +// Bytes returns a pointer to the raw HashDrbg seed. +func (seed *Seed) Bytes() *[SeedLength]byte { + return (*[SeedLength]byte)(seed) +} + +// Hex returns the hexdecimal representation of the seed. +func (seed *Seed) Hex() string { + return hex.EncodeToString(seed.Bytes()[:]) +} + +// NewSeed returns a Seed initialized with the runtime CSPRNG. +func NewSeed() (seed *Seed, err error) { + seed = new(Seed) + if err = csrand.Bytes(seed.Bytes()[:]); err != nil { + return nil, err + } + + return +} + +// SeedFromBytes creates a Seed from the raw bytes, truncating to SeedLength as +// appropriate. +func SeedFromBytes(src []byte) (seed *Seed, err error) { + if len(src) < SeedLength { + return nil, InvalidSeedLengthError(len(src)) + } + + seed = new(Seed) + copy(seed.Bytes()[:], src) + + return +} + +// SeedFromHex creates a Seed from the hexdecimal representation, truncating to +// SeedLength as appropriate. +func SeedFromHex(encoded string) (seed *Seed, err error) { + var raw []byte + if raw, err = hex.DecodeString(encoded); err != nil { + return nil, err + } + + return SeedFromBytes(raw) +} + +// InvalidSeedLengthError is the error returned when the seed provided to the +// DRBG is an invalid length. +type InvalidSeedLengthError int + +func (e InvalidSeedLengthError) Error() string { + return fmt.Sprintf("invalid seed length: %d", int(e)) +} + +// HashDrbg is a CSDRBG based off of SipHash-2-4 in OFB mode. +type HashDrbg struct { + sip hash.Hash64 + ofb [Size]byte +} + +// NewHashDrbg makes a HashDrbg instance based off an optional seed. The seed +// is truncated to SeedLength. +func NewHashDrbg(seed *Seed) (*HashDrbg, error) { + drbg := new(HashDrbg) + if seed == nil { + var err error + if seed, err = NewSeed(); err != nil { + return nil, err + } + } + drbg.sip = siphash.New(seed.Bytes()[:16]) + copy(drbg.ofb[:], seed.Bytes()[16:]) + + return drbg, nil +} + +// Int63 returns a uniformly distributed random integer [0, 1 << 63). +func (drbg *HashDrbg) Int63() int64 { + block := drbg.NextBlock() + ret := binary.BigEndian.Uint64(block) + ret &= (1<<63 - 1) + + return int64(ret) +} + +// Seed does nothing, call NewHashDrbg if you want to reseed. +func (drbg *HashDrbg) Seed(seed int64) { + // No-op. +} + +// NextBlock returns the next 8 byte DRBG block. +func (drbg *HashDrbg) NextBlock() []byte { + _, _ = drbg.sip.Write(drbg.ofb[:]) + copy(drbg.ofb[:], drbg.sip.Sum(nil)) + + ret := make([]byte, Size) + copy(ret, drbg.ofb[:]) + return ret +} diff --git a/common/framing/framing.go b/common/framing/framing.go new file mode 100644 index 0000000..795215c --- /dev/null +++ b/common/framing/framing.go @@ -0,0 +1,322 @@ +package framing + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + + "github.com/v2fly/riverrun/common/csrand" + "github.com/v2fly/riverrun/common/drbg" + "github.com/v2fly/riverrun/common/log" +) + +const ( + // MaximumSegmentLength is the length of the largest possible segment + // including overhead. + MaximumSegmentLength = 1500 - 52 + + // LengthLength is the number of bytes used to represent length + LengthLength = 2 + + // TypeLength is the number of bytes used to indicate packet type + TypeLength = 1 + + // MaxFrameLength is the maximum frame length + // MaxFrameLength = MaximumSegmentLength - LengthLength + + ConsumeReadSize = MaximumSegmentLength * 16 +) + +// ErrAgain is the error returned when decoding requires more data to continue. +var ErrAgain = errors.New("framing: More data needed to decode") + +// Error returned when Decoder.Decode() failes to authenticate a frame. +var ErrTagMismatch = errors.New("framing: Poly1305 tag mismatch") + +// InvalidPayloadLengthError is the error returned when Encoder.Encode() +// rejects the payload length. +type InvalidPayloadLengthError int + +func (e InvalidPayloadLengthError) Error() string { + return fmt.Sprintf("framing: Invalid payload length: %d", int(e)) +} + +// InvalidPacketLengthError is the error returned when decodePacket detects a +// invalid packet length/ +type InvalidPacketLengthError int + +func (e InvalidPacketLengthError) Error() string { + return fmt.Sprintf("packet: Invalid packet length: %d", int(e)) +} + +type encodeFunc func(frame, payload []byte) (n int, err error) +type chopPayloadFunc func(pktType uint8, payload []byte) []byte +type overheadFunc func(payloadLen int) int +type processLengthFunc func(length uint16) ([]byte, error) + +// BaseEncoder implements the core encoder vars and functions +type BaseEncoder struct { + Drbg *drbg.HashDrbg + MaxPacketPayloadLength int + LengthLength int + PayloadOverhead overheadFunc + + Encode encodeFunc + ProcessLength processLengthFunc + ChopPayload chopPayloadFunc + + Type string +} + +// TODO: Only do this for riverrun encoder + +func (encoder *BaseEncoder) MakePacket(w io.Writer, payload []byte) error { + // Encode the packet in an AEAD frame. + var frame [MaximumSegmentLength]byte + payloadLen := len(payload) + payloadLenWithOverhead0 := payloadLen + encoder.PayloadOverhead(payloadLen) + if len(frame)-encoder.LengthLength < payloadLenWithOverhead0 { + return io.ErrShortBuffer + } + length := uint16(payloadLenWithOverhead0) + lengthMask := encoder.Drbg.NextBlock() + length ^= binary.BigEndian.Uint16(lengthMask) + processedLength, err := encoder.ProcessLength(length) + if err != nil { + return err + } + copy(frame[:encoder.LengthLength], processedLength) + frameLen := encoder.LengthLength + payloadLenWithOverhead0 + payloadLenWithOverhead1, err := encoder.Encode(frame[encoder.LengthLength:], payload[:payloadLen]) + if err != nil { + // All encoder errors are fatal. + return err + } + + if payloadLenWithOverhead0 != payloadLenWithOverhead1 { + panic(fmt.Sprintf("BUG: MakePacket(), frame lengths do not align, %d %d", payloadLenWithOverhead0, payloadLenWithOverhead1)) + } + + wrLen, err := w.Write(frame[:frameLen]) + if err != nil { + return err + } else if wrLen < frameLen { + return io.ErrShortWrite + } + + return nil +} + +// Chop the pending data into payload frames. +func (encoder *BaseEncoder) Chop(b []byte, pktType uint8) (frameBuf bytes.Buffer, n int, err error) { + chopBuf := bytes.NewBuffer(b) + payload := make([]byte, encoder.MaxPacketPayloadLength) + for chopBuf.Len() > 0 { + // Send maximum sized frames. + rdLen := 0 + rdLen, err = chopBuf.Read(payload[:]) + if err != nil { + return frameBuf, 0, err + } else if rdLen == 0 { + panic(fmt.Sprintf("BUG: Chop(), chopping length was 0")) + } + n += rdLen + err = encoder.MakePacket(&frameBuf, encoder.ChopPayload(pktType, payload[:rdLen])) + if err != nil { + return frameBuf, 0, err + } + } + return +} + +type decodeLengthfunc func(lengthBytes []byte) (uint16, error) +type decodePayloadfunc func(frames *bytes.Buffer) ([]byte, error) +type parsePacketFunc func(decoded []byte, decLen int) error +type cleanupfunc func() error +type BaseDecoder struct { + Drbg *drbg.HashDrbg + LengthLength int + MinPayloadLength int + PacketOverhead int + MaxFramePayloadLength int + + NextLength uint16 + NextLengthInvalid bool + + PayloadOverhead overheadFunc + + DecodeLength decodeLengthfunc + DecodePayload decodePayloadfunc + ParsePacket parsePacketFunc + Cleanup cleanupfunc + + ReceiveBuffer *bytes.Buffer + ReceiveDecodedBuffer *bytes.Buffer + readBuffer []byte +} + +func (decoder *BaseDecoder) InitBuffers() { + decoder.ReceiveBuffer = bytes.NewBuffer(nil) + decoder.ReceiveDecodedBuffer = bytes.NewBuffer(nil) + decoder.readBuffer = make([]byte, ConsumeReadSize) +} + +func (decoder *BaseDecoder) GetFrame(frames *bytes.Buffer) (int, []byte, error) { + maximumPayloadLength := MaximumSegmentLength - decoder.LengthLength + singleFrame := make([]byte, maximumPayloadLength) + n, err := io.ReadFull(frames, singleFrame[:decoder.NextLength]) + if err != nil { + return 0, nil, err + } + return n, singleFrame, nil +} + +func (decoder *BaseDecoder) Read(b []byte, conn net.Conn) (n int, err error) { + // If there is no payload from the previous Read() calls, consume data off + // the network. Not all data received is guaranteed to be usable payload, + // so do this in a loop till data is present or an error occurs. + for decoder.ReceiveDecodedBuffer.Len() == 0 { + err = decoder.readPackets(conn) + if err == ErrAgain { + // Don't proagate this back up the call stack if we happen to break + // out of the loop. + err = nil + continue + } else if err != nil { + break + } + } + + // Even if err is set, attempt to do the read anyway so that all decoded + // data gets relayed before the connection is torn down. + if decoder.ReceiveDecodedBuffer.Len() > 0 { + var berr error + n, berr = decoder.ReceiveDecodedBuffer.Read(b) + if err == nil { + // Only propagate berr if there are not more important (fatal) + // errors from the network/crypto/packet processing. + err = berr + } + } + + return +} + +func (decoder *BaseDecoder) readPackets(conn net.Conn) (err error) { + // Attempt to read off the network. + rdLen, rdErr := conn.Read(decoder.readBuffer) + decoder.ReceiveBuffer.Write(decoder.readBuffer[:rdLen]) + + decoded := make([]byte, decoder.MaxFramePayloadLength) + for decoder.ReceiveBuffer.Len() > 0 { + // Decrypt an AEAD frame. + decLen := 0 + decLen, err = decoder.Decode(decoded[:], decoder.ReceiveBuffer) + if err == ErrAgain { + break + } else if err != nil { + break + } else if decLen < decoder.PacketOverhead { + err = InvalidPacketLengthError(decLen) + break + } + + err = decoder.ParsePacket(decoded, decLen) + if err != nil { + break + } + } + + // Read errors (all fatal) take priority over various frame processing + // errors. + if rdErr != nil { + return rdErr + } + + return +} + +// Decode decodes a stream of data and returns the length if any. ErrAgain is +// a temporary failure, all other errors MUST be treated as fatal and the +// session aborted. +func (decoder *BaseDecoder) Decode(data []byte, frames *bytes.Buffer) (int, error) { + + // A length of 0 indicates that we do not know how big the next frame is + // going to be. + if decoder.NextLength == 0 { + // Attempt to pull out the next frame length. + if decoder.LengthLength > frames.Len() { + return 0, ErrAgain + } + + lengthlength := make([]byte, decoder.LengthLength) + _, err := io.ReadFull(frames, lengthlength[:]) + if err != nil { + return 0, err + } + // Deobfuscate the length field. + length, err := decoder.DecodeLength(lengthlength) + if err != nil { + return 0, err + } + lengthMask := decoder.Drbg.NextBlock() + log.Debugf("length (raw): %d, length (mask): %d", length, lengthMask) + length ^= binary.BigEndian.Uint16(lengthMask) + log.Debugf("First nextLength: %d", length) + if MaximumSegmentLength-int(decoder.LengthLength) < int(length) || decoder.MinPayloadLength > int(length) { + // Per "Plaintext Recovery Attacks Against SSH" by + // Martin R. Albrecht, Kenneth G. Paterson and Gaven J. Watson, + // there are a class of attacks againt protocols that use similar + // sorts of framing schemes. + // + // While obfs4 should not allow plaintext recovery (CBC mode is + // not used), attempt to mitigate out of bound frame length errors + // by pretending that the length was a random valid range as per + // the countermeasure suggested by Denis Bider in section 6 of the + // paper. + log.Debugf("Bad length") + decoder.NextLengthInvalid = true + length = uint16(csrand.IntRange(decoder.MinPayloadLength, MaximumSegmentLength-int(decoder.LengthLength))) + } + log.Debugf("Out nextLength: %d", length) + decoder.NextLength = length + } + + if int(decoder.NextLength) > frames.Len() { + return 0, ErrAgain + } + + decodedPayload, err := decoder.DecodePayload(frames) + if err != nil { + return 0, err + } + copy(data[0:len(decodedPayload)], decodedPayload[:]) + + if decoder.NextLengthInvalid { + // When a random length is used be paranoid. + return 0, ErrTagMismatch + } + + // Clean up and prepare for the next frame. + decoder.NextLength = 0 + return len(decodedPayload), decoder.Cleanup() +} + +// GenDrbg creates a *drbg.HashDrbg with some safety checks +func GenDrbg(key []byte) *drbg.HashDrbg { + if len(key) != drbg.SeedLength { + panic(fmt.Sprintf("BUG: Failed to initialize DRBG: Invalid Keylength, must be %d (drbg.SeedLength)", drbg.SeedLength)) + } + seed, err := drbg.SeedFromBytes(key[:]) + if err != nil { + panic(fmt.Sprintf("BUG: Failed to initialize DRBG: %s", err)) + } + res, err := drbg.NewHashDrbg(seed) + if err != nil { + panic(fmt.Sprintf("BUG: Failed to initialize DRBG: %s", err)) + } + return res +} diff --git a/common/log/log.go b/common/log/log.go new file mode 100644 index 0000000..a393a5e --- /dev/null +++ b/common/log/log.go @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2014-2015, Yawning Angel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +// Package log implements a simple set of leveled logging wrappers around the +// standard log package. +package log // import "github.com/RACECAR-GU/obfsX/common/log" + +type Logger interface { + Infof(format string, a ...interface{}) + Debugf(format string, a ...interface{}) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..bc9c2f4 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/v2fly/riverrun + +go 1.21.7 + +require ( + github.com/RACECAR-GU/obfsX v0.0.0-20230217184022-1add4680bcda + github.com/dchest/siphash v1.2.3 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..68f7430 --- /dev/null +++ b/go.sum @@ -0,0 +1,51 @@ +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +git.schwanenlied.me/yawning/bsaes.git v0.0.0-20190320102049-26d1add596b6/go.mod h1:BWqTsj8PgcPriQJGl7el20J/7TuT1d/hSyFDXMEpoEo= +git.torproject.org/pluggable-transports/goptlib.git v1.0.0/go.mod h1:YT4XMSkuEXbtqlydr9+OxqFAyspUv0Gr9qhM3B++o/Q= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/RACECAR-GU/obfsX v0.0.0-20230217184022-1add4680bcda h1:o7+r8ZcyQDHQlFeLL6+zOR82u4M1f9qj4c1l4DVDNfU= +github.com/RACECAR-GU/obfsX v0.0.0-20230217184022-1add4680bcda/go.mod h1:8FvMyONX77SdOXSsF4MHelLSeZx6sVWRp7kKyhVSJXQ= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= +github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA= +github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc= +github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= +github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= +gitlab.com/yawning/utls.git v0.0.11-1/go.mod h1:eYdrOOCoedNc3xw50kJ/s8JquyxeS5kr3vkFZFPTI9w= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/net v0.0.0-20190328230028-74de082e2cca/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190329044733-9eb1bfa1ce65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.0.0-20191009222026-5d5638e6749a/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/riverrun.go b/riverrun.go new file mode 100644 index 0000000..c58f515 --- /dev/null +++ b/riverrun.go @@ -0,0 +1,400 @@ +package riverrun + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "fmt" + "io" + "math/rand" + "net" + "sync" + + "github.com/v2fly/riverrun/common/ctstretch" + "github.com/v2fly/riverrun/common/drbg" + f "github.com/v2fly/riverrun/common/framing" + "github.com/v2fly/riverrun/common/log" +) + +const ( + PacketTypePayload = iota +) + +// Implements the net.Conn interface +type Conn struct { + // Embeds a net.Conn and inherits its members. + net.Conn + + logger log.Logger + + bias float64 + mss_max int + mss_dev float64 + + Encoder *riverrunEncoder + Decoder *riverrunDecoder +} + +func get_rng(seed *drbg.Seed) (*rand.Rand, error) { + xdrbg, err := drbg.NewHashDrbg(seed) + if err != nil { + return nil, err + } + return rand.New(xdrbg), nil +} + +func get_mss(seed *drbg.Seed) (int, error) { + rng, err := get_rng(seed) + if err != nil { + return 0, err + } + return int(rng.Float64()*float64(800)) + 600, nil +} + +func NewConn(conn net.Conn, isServer bool, seed *drbg.Seed, logger log.Logger) (*Conn, error) { + + rng, err := get_rng(seed) + + key := make([]byte, 16) + rng.Read(key) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // We select the minimal expansion factors + // The full range is commented out + compressedBlockBits := uint64(16) // uint64((rng.Intn(2) + 1) * 8) + + var expandedBlockBits uint64 + var expandedBlockBits8 uint64 + if compressedBlockBits == 8 { + expandedBlockBits = uint64((rng.Intn(6) + 3) * 8) + expandedBlockBits8 = expandedBlockBits + } else { + expandedBlockBits = 32 // uint64((rng.Intn(3) + 2) * 16) + expandedBlockBits8 = expandedBlockBits / 2 + } + + bias := rng.Float64()*.2 + .1 // Targeting entropy of 4-7 based on observations + + logger.Infof("rr: Set bias to %f, compressed block bits to %d, expanded block bits to %d", bias, compressedBlockBits, expandedBlockBits) + + iv := make([]byte, block.BlockSize()) + rng.Read(iv) + table8, table16, err := getTables(expandedBlockBits8, expandedBlockBits, bias, key, block, iv, logger) + if err != nil { + return nil, err + } + + var readStream, writeStream cipher.Stream + rng.Read(iv) + stream := cipher.NewCTR(block, iv) + readKey := make([]byte, drbg.SeedLength) + writeKey := make([]byte, drbg.SeedLength) + logger.Debugf("riverrun: r/w keys made") + + if isServer { + readStream = stream + rng.Read(iv) + writeStream = cipher.NewCTR(block, iv) + rng.Read(readKey) + rng.Read(writeKey) + } else { + writeStream = stream + rng.Read(iv) + readStream = cipher.NewCTR(block, iv) + rng.Read(writeKey) + rng.Read(readKey) + } + logger.Debugf("riverrun: Loaded keys properly") + rr := new(Conn) + rr.Conn = conn + rr.logger = logger + rr.bias = bias + rr.mss_max, err = get_mss(seed) + if err != nil { + return nil, err + } + rr.mss_dev = rng.Float64() * 4 + logger.Infof("Set mss_max to %v, mss_dev to %v", rr.mss_max, rr.mss_dev) + // Encoder + rr.Encoder = newRiverrunEncoder(writeKey, writeStream, table8, table16, compressedBlockBits, expandedBlockBits, logger) + logger.Debugf("riverrun: Encoder initialized") + // Decoder + rr.Decoder = newRiverrunDecoder(readKey, readStream, ctstretch.InvertTable(table8), ctstretch.InvertTable(table16), compressedBlockBits, expandedBlockBits, logger) + logger.Debugf("riverrun: Initialized") + return rr, nil +} + +var cache8 map[string][]uint64 +var cache16 map[string][]uint64 +var mutex = &sync.Mutex{} + +func getTables(expandedBlockBits8 uint64, expandedBlockBits uint64, bias float64, key []byte, block cipher.Block, iv []byte, logger log.Logger) ([]uint64, []uint64, error) { + + mutex.Lock() + if cache8 == nil { + cache8 = make(map[string][]uint64) + } + if cache16 == nil { + cache16 = make(map[string][]uint64) + } + mutex.Unlock() + + mutex.Lock() + table8, ok := cache8[string(key)] + mutex.Unlock() + if ok { + mutex.Lock() + table16, ok := cache16[string(key)] + mutex.Unlock() + if ok { + logger.Debugf("riverrun: using cached tables") + return table8, table16, nil + } + } + + logger.Debugf("riverrun: Generating fresh tables") + stream := cipher.NewCTR(block, iv) + + table8, err := ctstretch.SampleBiasedStrings(expandedBlockBits8, 256, bias, stream) + if err != nil { + return nil, nil, err + } + logger.Debugf("riverrun: table8 prepped") + table16, err := ctstretch.SampleBiasedStrings(expandedBlockBits, 65536, bias, stream) + if err != nil { + return nil, nil, err + } + logger.Debugf("riverrun: table16 prepped") + + mutex.Lock() + cache8[string(key)] = table8 + cache16[string(key)] = table16 + mutex.Unlock() + + return table8, table16, nil +} + +type riverrunEncoder struct { + f.BaseEncoder + + logger log.Logger + + writeStream cipher.Stream + + table8 []uint64 + table16 []uint64 + + compressedBlockBits uint64 + expandedBlockBits uint64 +} + +func (encoder *riverrunEncoder) payloadOverhead(payloadLen int) int { + return int(ctstretch.ExpandedNBytes(uint64(payloadLen), encoder.compressedBlockBits, encoder.expandedBlockBits)) - payloadLen +} +func (decoder *riverrunDecoder) payloadOverhead(payloadLen int) int { + return int(ctstretch.ExpandedNBytes(uint64(payloadLen), decoder.compressedBlockBits, decoder.expandedBlockBits)) - payloadLen +} + +func newRiverrunEncoder(key []byte, writeStream cipher.Stream, table8, table16 []uint64, compressedBlockBits, expandedBlockBits uint64, logger log.Logger) *riverrunEncoder { + encoder := new(riverrunEncoder) + + encoder.Drbg = f.GenDrbg(key[:]) + encoder.MaxPacketPayloadLength = int(ctstretch.CompressedNBytes_floor(f.MaximumSegmentLength-ctstretch.ExpandedNBytes(uint64(f.LengthLength), compressedBlockBits, expandedBlockBits), expandedBlockBits, compressedBlockBits)) + encoder.LengthLength = int(ctstretch.ExpandedNBytes(uint64(f.LengthLength), compressedBlockBits, expandedBlockBits)) + encoder.PayloadOverhead = encoder.payloadOverhead + + encoder.Encode = encoder.encode + encoder.ProcessLength = encoder.processLength + encoder.ChopPayload = encoder.makePayload + + encoder.writeStream = writeStream + encoder.table8 = table8 + encoder.table16 = table16 + encoder.compressedBlockBits = compressedBlockBits + encoder.expandedBlockBits = expandedBlockBits + + encoder.Type = "rr" + + encoder.logger = logger + + return encoder +} + +func (encoder *riverrunEncoder) processLength(length uint16) ([]byte, error) { + lengthBytes := make([]byte, f.LengthLength) + binary.BigEndian.PutUint16(lengthBytes[:], length) + lengthBytesEncoded := make([]byte, encoder.LengthLength) + err := ctstretch.ExpandBytes(lengthBytes[:], lengthBytesEncoded, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, rand.Int()) + return lengthBytesEncoded, err +} + +func (encoder *riverrunEncoder) encode(frame, payload []byte) (n int, err error) { + tb := rand.Int() + expandedNBytes := int(ctstretch.ExpandedNBytes(uint64(len(payload)), encoder.compressedBlockBits, encoder.expandedBlockBits)) + frameLen := encoder.LengthLength + expandedNBytes + encoder.logger.Debugf("Encoding frame of length %d, with payload of length %d. TB: %d", frameLen, expandedNBytes, tb) + err = ctstretch.ExpandBytes(payload[:], frame, encoder.compressedBlockBits, encoder.expandedBlockBits, encoder.table16, encoder.table8, encoder.writeStream, tb) + if err != nil { + return 0, err + } + return expandedNBytes, err +} +func (encoder *riverrunEncoder) makePayload(pktType uint8, payload []byte) []byte { + if pktType != PacketTypePayload { + panic(fmt.Sprintf("BUG: pktType was not packetTypePayload for Riverrun")) + } + return payload[:] +} + +type riverrunDecoder struct { + f.BaseDecoder + + readStream cipher.Stream + + revTable8 map[uint64]uint64 + revTable16 map[uint64]uint64 + + compressedBlockBits uint64 + expandedBlockBits uint64 + + logger log.Logger +} + +func newRiverrunDecoder(key []byte, readStream cipher.Stream, revTable8, revTable16 map[uint64]uint64, compressedBlockBits, expandedBlockBits uint64, logger log.Logger) *riverrunDecoder { + decoder := new(riverrunDecoder) + + decoder.Drbg = f.GenDrbg(key[:]) + decoder.LengthLength = int(ctstretch.ExpandedNBytes(uint64(f.LengthLength), compressedBlockBits, expandedBlockBits)) + decoder.MinPayloadLength = int(ctstretch.ExpandedNBytes(uint64(1), compressedBlockBits, expandedBlockBits)) + decoder.PacketOverhead = 0 // f.LengthLength + decoder.MaxFramePayloadLength = f.MaximumSegmentLength - decoder.LengthLength + + // NextLength is set programatically + // NextLengthInvalid is set programatically + + decoder.PayloadOverhead = decoder.payloadOverhead + + decoder.DecodeLength = decoder.decodeLength + decoder.DecodePayload = decoder.decodePayload + decoder.ParsePacket = decoder.parsePacket + decoder.Cleanup = decoder.cleanup + + decoder.InitBuffers() + + decoder.readStream = readStream + decoder.revTable8 = revTable8 + decoder.revTable16 = revTable16 + decoder.compressedBlockBits = compressedBlockBits + decoder.expandedBlockBits = expandedBlockBits + + decoder.logger = logger + return decoder +} + +func (decoder *riverrunDecoder) cleanup() error { + return nil +} + +func (decoder *riverrunDecoder) decodeLength(lengthBytes []byte) (uint16, error) { + var decodedBytes [f.LengthLength]byte + err := decoder.compressBytes(lengthBytes[:decoder.LengthLength], decodedBytes[:]) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint16(decodedBytes[:f.LengthLength]), err +} + +func (decoder *riverrunDecoder) parsePacket(decoded []byte, decLen int) error { + /* + originalNBytes := binary.BigEndian.Uint16(decoded[:f.LengthLength]) // TODO: Ensure this is encoded + if int(originalNBytes) > decLen-decoder.PacketOverhead { + return f.InvalidPayloadLengthError(int(originalNBytes)) + } + */ + decoder.ReceiveDecodedBuffer.Write(decoded[decoder.PacketOverhead:decLen]) + return nil +} + +func (decoder *riverrunDecoder) decodePayload(frames *bytes.Buffer) ([]byte, error) { + //var frame []byte + //var frameLen int + frameLen, frame, err := decoder.GetFrame(frames) + if err != nil { + return nil, err + } + + compressedNBytes := ctstretch.CompressedNBytes(uint64(frameLen), decoder.expandedBlockBits, decoder.compressedBlockBits) + decodedPayload := make([]byte, compressedNBytes) + err = decoder.compressBytes(frame[:frameLen], decodedPayload[:compressedNBytes]) + if err != nil { + decoder.logger.Debugf("Max payload length is %d", int(ctstretch.CompressedNBytes_floor(f.MaximumSegmentLength-ctstretch.ExpandedNBytes(uint64(f.LengthLength), decoder.compressedBlockBits, decoder.expandedBlockBits), decoder.expandedBlockBits, decoder.compressedBlockBits))) + decoder.logger.Debugf("CompressedNBytes: %d", compressedNBytes) + decoder.logger.Debugf("Got payload of len %d", frameLen) + return nil, err + } + + return decodedPayload[:], nil +} + +func (decoder *riverrunDecoder) compressBytes(raw, res []byte) error { + return ctstretch.CompressBytes(raw, res, decoder.expandedBlockBits, decoder.compressedBlockBits, decoder.revTable16, decoder.revTable8, decoder.readStream, rand.Int()) +} + +func (rr *Conn) nextLength() int { + noise := rand.NormFloat64() * rr.mss_dev + if noise < 0 { + noise = noise * -1 + } + if int(noise) >= rr.mss_max { + return rr.nextLength() + } + return rr.mss_max - int(noise) +} + +func (rr *Conn) Write(b []byte) (n int, err error) { + + // XXX: n could be more accurate + var frameBuf bytes.Buffer + frameBuf, n, err = rr.Encoder.Chop(b, PacketTypePayload) + if err != nil { + return + } + + // We do obfuscation here - experimental results found the + // constant near MSS sizes were detectable + for { + nextLength := rr.nextLength() + toWire := make([]byte, nextLength) + + s, e := frameBuf.Read(toWire) + if e != nil { + if e != io.EOF { + err = e + } + return + } + + rr.logger.Debugf("Next length: %v", s) + + _, err = rr.Conn.Write(toWire[:s]) + if err != nil { + return + } + } + + //log.Debugf("Riverrun: %d expanded to %d ->", n, lowerConnN) + // TODO: What does spec say about returned numbers? + // Should they be bytes written, or the raw bytes before expansion expanded? + // Idea: Bytes written (raw), Bytes written (processed), err - raw bytes is equivalent to old n +} + +func (rr *Conn) Read(b []byte) (int, error) { + //originalLen := len(b) + n, err := rr.Decoder.Read(b, rr.Conn) + //log.Debugf("Riverrun: %d compressed to %d <-", originalLen, n) + return n, err +}