Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: new Shadowsocks validator #629

Merged
merged 9 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 5 additions & 20 deletions proxy/shadowsocks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"crypto/md5"
"crypto/sha1"
"io"
"reflect"
"strconv"

"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
Expand All @@ -28,6 +26,10 @@ type MemoryAccount struct {
replayFilter antireplay.GeneralizedReplayFilter
}

var (
ErrIVNotUnique = newError("IV is not unique")
)

// Equals implements protocol.Account.Equals().
func (a *MemoryAccount) Equals(another protocol.Account) bool {
if account, ok := another.(*MemoryAccount); ok {
Expand All @@ -43,24 +45,7 @@ func (a *MemoryAccount) CheckIV(iv []byte) error {
if a.replayFilter.Check(iv) {
return nil
}
return newError("IV is not unique")
}

func (a *MemoryAccount) GetCipherName() string {
switch a.Cipher.(type) {
case *AEADCipher:
switch reflect.ValueOf(a.Cipher.(*AEADCipher).AEADAuthCreator).Pointer() {
case reflect.ValueOf(createAesGcm).Pointer():
keyBytes := a.Cipher.(*AEADCipher).KeyBytes
return "AES_" + strconv.FormatInt(int64(keyBytes*8), 10) + "_GCM"
case reflect.ValueOf(createChaCha20Poly1305).Pointer():
return "CHACHA20_POLY1305"
}
case *NoneCipher:
return "NONE"
}

return ""
return ErrIVNotUnique
}

func createAesGcm(key []byte) cipher.AEAD {
Expand Down
132 changes: 44 additions & 88 deletions proxy/shadowsocks/protocol.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package shadowsocks

import (
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"hash/crc32"
"io"

"github.com/xtls/xray-core/common"
Expand Down Expand Up @@ -54,91 +50,67 @@ func (r *FullReader) Read(p []byte) (n int, err error) {

// ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
func ReadTCPSession(validator *Validator, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
hashkdf := hmac.New(sha256.New, []byte("SSBSKDF"))

behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))

behaviorSeed := validator.GetBehaviorSeed()
behaviorRand := dice.NewDeterministicDice(int64(behaviorSeed))
BaseDrainSize := behaviorRand.Roll(3266)
RandDrainMax := behaviorRand.Roll(64) + 1
RandDrainRolled := dice.Roll(RandDrainMax)
DrainSize := BaseDrainSize + 16 + 38 + RandDrainRolled
readSizeRemain := DrainSize

var r2 buf.Reader
var r buf.Reader
buffer := buf.New()
defer buffer.Release()

var user *protocol.MemoryUser
var ivLen int32
var iv []byte
var err error

count := validator.Count()
if count == 0 {
if _, err := buffer.ReadFullFrom(reader, 50); err != nil {
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("invalid user")
} else if count > 1 {
var aead cipher.AEAD

if _, err := buffer.ReadFullFrom(reader, 50); err != nil {
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to read 50 bytes").Base(err)
}
return nil, nil, newError("failed to read 50 bytes").Base(err)
}

bs := buffer.Bytes()
user, aead, _, ivLen, err = validator.Get(bs, protocol.RequestCommandTCP)
bs := buffer.Bytes()
user, aead, _, ivLen, err := validator.Get(bs, protocol.RequestCommandTCP)

if user != nil {
if ivLen > 0 {
iv = append([]byte(nil), bs[:ivLen]...)
}
reader = &FullReader{reader, bs[ivLen:]}
switch err {
case ErrNotFound:
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to match an user").Base(err)
case ErrIVNotUnique:
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed iv check").Base(err)
default:
reader = &FullReader{reader, bs[ivLen:]}
readSizeRemain -= int(ivLen)

if aead != nil {
auth := &crypto.AEADAuthenticator{
AEAD: aead,
NonceGenerator: crypto.GenerateInitialAEADNonce(),
}
r2 = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{
r = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{
Auth: auth,
}, reader, protocol.TransferTypeStream, nil)
} else {
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to match an user").Base(err)
}
} else {
user, ivLen = validator.GetOnlyUser()
account := user.Account.(*MemoryAccount)
hashkdf.Write(account.Key)
if ivLen > 0 {
if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
readSizeRemain -= int(buffer.Len())
account := user.Account.(*MemoryAccount)
iv := append([]byte(nil), buffer.BytesTo(ivLen)...)
r, err = account.Cipher.NewDecryptionReader(account.Key, iv, reader)
if err != nil {
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to read IV").Base(err)
return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
}
iv = append([]byte(nil), buffer.BytesTo(ivLen)...)
}

r, err := account.Cipher.NewDecryptionReader(account.Key, iv, reader)
if err != nil {
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
}
r2 = r
}

br := &buf.BufferedReader{Reader: r2}
br := &buf.BufferedReader{Reader: r}

request := &protocol.RequestHeader{
Version: Version,
User: user,
Command: protocol.RequestCommandTCP,
}

readSizeRemain -= int(buffer.Len())
buffer.Clear()

addr, port, err := addrParser.ReadAddressPort(buffer, br)
Expand All @@ -157,13 +129,6 @@ func ReadTCPSession(validator *Validator, reader io.Reader) (*protocol.RequestHe
return nil, nil, newError("invalid remote address.")
}

account := user.Account.(*MemoryAccount)
if ivError := account.CheckIV(iv); ivError != nil {
readSizeRemain -= int(buffer.Len())
DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed iv check").Base(ivError)
}

return request, br, nil
}

Expand Down Expand Up @@ -273,34 +238,25 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
return nil, nil, newError("len(bs) <= 32")
}

var user *protocol.MemoryUser
var err error

count := validator.Count()
if count == 0 {
return nil, nil, newError("invalid user")
} else if count > 1 {
var d []byte
user, _, d, _, err = validator.Get(bs, protocol.RequestCommandUDP)

if user != nil {
user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP)
switch err {
case ErrIVNotUnique:
return nil, nil, newError("failed iv check").Base(err)
case ErrNotFound:
return nil, nil, newError("failed to match an user").Base(err)
default:
account := user.Account.(*MemoryAccount)
if account.Cipher.IsAEAD() {
payload.Clear()
payload.Write(d)
} else {
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
}
} else {
user, _ = validator.GetOnlyUser()
account := user.Account.(*MemoryAccount)

var iv []byte
if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
// Keep track of IV as it gets removed from payload in DecodePacket.
iv = make([]byte, account.Cipher.IVSize())
copy(iv, payload.BytesTo(account.Cipher.IVSize()))
}
if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
if account.Cipher.IVSize() > 0 {
iv := make([]byte, account.Cipher.IVSize())
copy(iv, payload.BytesTo(account.Cipher.IVSize()))
}
if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
}
}
}

Expand Down
4 changes: 0 additions & 4 deletions proxy/shadowsocks/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
panic("no inbound metadata")
}

if s.validator.Count() == 1 {
inbound.User, _ = s.validator.GetOnlyUser()
}

var dest *net.Destination

reader := buf.NewPacketReader(conn)
Expand Down
Loading