diff --git a/common/net/cnc/connection.go b/common/net/cnc/connection.go index bdae5409673b..519918fd47e8 100644 --- a/common/net/cnc/connection.go +++ b/common/net/cnc/connection.go @@ -10,46 +10,46 @@ import ( "github.com/xtls/xray-core/common/signal/done" ) -type ConnectionOption func(*connection) +type ConnectionOption func(*Connection) func ConnectionLocalAddr(a net.Addr) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.local = a } } func ConnectionRemoteAddr(a net.Addr) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.remote = a } } func ConnectionInput(writer io.Writer) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.writer = buf.NewWriter(writer) } } func ConnectionInputMulti(writer buf.Writer) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.writer = writer } } func ConnectionOutput(reader io.Reader) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)} } } func ConnectionOutputMulti(reader buf.Reader) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.reader = &buf.BufferedReader{Reader: reader} } } func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.reader = &buf.BufferedReader{ Reader: reader, Splitter: buf.SplitFirstBytes, @@ -58,13 +58,13 @@ func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption { } func ConnectionOnClose(n io.Closer) ConnectionOption { - return func(c *connection) { + return func(c *Connection) { c.onClose = n } } func NewConnection(opts ...ConnectionOption) net.Conn { - c := &connection{ + c := &Connection{ done: done.New(), local: &net.TCPAddr{ IP: []byte{0, 0, 0, 0}, @@ -83,7 +83,7 @@ func NewConnection(opts ...ConnectionOption) net.Conn { return c } -type connection struct { +type Connection struct { reader *buf.BufferedReader writer buf.Writer done *done.Instance @@ -92,17 +92,17 @@ type connection struct { remote net.Addr } -func (c *connection) Read(b []byte) (int, error) { +func (c *Connection) Read(b []byte) (int, error) { return c.reader.Read(b) } // ReadMultiBuffer implements buf.Reader. -func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) { +func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { return c.reader.ReadMultiBuffer() } // Write implements net.Conn.Write(). -func (c *connection) Write(b []byte) (int, error) { +func (c *Connection) Write(b []byte) (int, error) { if c.done.Done() { return 0, io.ErrClosedPipe } @@ -113,7 +113,7 @@ func (c *connection) Write(b []byte) (int, error) { return l, c.writer.WriteMultiBuffer(mb) } -func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { +func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { if c.done.Done() { buf.ReleaseMulti(mb) return io.ErrClosedPipe @@ -123,7 +123,7 @@ func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { } // Close implements net.Conn.Close(). -func (c *connection) Close() error { +func (c *Connection) Close() error { common.Must(c.done.Close()) common.Interrupt(c.reader) common.Close(c.writer) @@ -135,26 +135,26 @@ func (c *connection) Close() error { } // LocalAddr implements net.Conn.LocalAddr(). -func (c *connection) LocalAddr() net.Addr { +func (c *Connection) LocalAddr() net.Addr { return c.local } // RemoteAddr implements net.Conn.RemoteAddr(). -func (c *connection) RemoteAddr() net.Addr { +func (c *Connection) RemoteAddr() net.Addr { return c.remote } // SetDeadline implements net.Conn.SetDeadline(). -func (c *connection) SetDeadline(t time.Time) error { +func (c *Connection) SetDeadline(t time.Time) error { return nil } // SetReadDeadline implements net.Conn.SetReadDeadline(). -func (c *connection) SetReadDeadline(t time.Time) error { +func (c *Connection) SetReadDeadline(t time.Time) error { return nil } // SetWriteDeadline implements net.Conn.SetWriteDeadline(). -func (c *connection) SetWriteDeadline(t time.Time) error { +func (c *Connection) SetWriteDeadline(t time.Time) error { return nil } diff --git a/infra/conf/freedom.go b/infra/conf/freedom.go index 65930a2e8a29..82d2c9a8bc2a 100644 --- a/infra/conf/freedom.go +++ b/infra/conf/freedom.go @@ -10,8 +10,8 @@ import ( v2net "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/proxy/freedom" - "google.golang.org/protobuf/proto" "github.com/xtls/xray-core/transport/internet" + "google.golang.org/protobuf/proto" ) type FreedomConfig struct { diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index b4458096b458..823de05b40ae 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -18,6 +18,8 @@ import ( "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/finalmask/fragment" + "github.com/xtls/xray-core/transport/internet/finalmask/header/custom" "github.com/xtls/xray-core/transport/internet/finalmask/header/dns" "github.com/xtls/xray-core/transport/internet/finalmask/header/dtls" "github.com/xtls/xray-core/transport/internet/finalmask/header/srtp" @@ -26,6 +28,7 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/header/wireguard" "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/aes128gcm" "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" + "github.com/xtls/xray-core/transport/internet/finalmask/noise" "github.com/xtls/xray-core/transport/internet/finalmask/salamander" "github.com/xtls/xray-core/transport/internet/finalmask/xdns" "github.com/xtls/xray-core/transport/internet/finalmask/xicmp" @@ -1276,8 +1279,48 @@ func (c *SocketConfig) Build() (*internet.SocketConfig, error) { }, nil } +func PraseByteSlice(data json.RawMessage, typ string) ([]byte, error) { + switch strings.ToLower(typ) { + case "", "array": + if len(data) == 0 { + return data, nil + } + var packet []byte + if err := json.Unmarshal(data, &packet); err != nil { + return nil, err + } + return packet, nil + case "str": + var str string + if err := json.Unmarshal(data, &str); err != nil { + return nil, err + } + return []byte(str), nil + case "hex": + var str string + if err := json.Unmarshal(data, &str); err != nil { + return nil, err + } + return hex.DecodeString(str) + case "base64": + var str string + if err := json.Unmarshal(data, &str); err != nil { + return nil, err + } + return base64.StdEncoding.DecodeString(str) + default: + return nil, errors.New("unknown type") + } +} + var ( + tcpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{ + "header-custom": func() interface{} { return new(HeaderCustomTCP) }, + "fragment": func() interface{} { return new(FragmentMask) }, + }, "type", "settings") + udpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{ + "header-custom": func() interface{} { return new(HeaderCustomUDP) }, "header-dns": func() interface{} { return new(Dns) }, "header-dtls": func() interface{} { return new(Dtls) }, "header-srtp": func() interface{} { return new(Srtp) }, @@ -1286,12 +1329,245 @@ var ( "header-wireguard": func() interface{} { return new(Wireguard) }, "mkcp-original": func() interface{} { return new(Original) }, "mkcp-aes128gcm": func() interface{} { return new(Aes128Gcm) }, + "noise": func() interface{} { return new(NoiseMask) }, "salamander": func() interface{} { return new(Salamander) }, "xdns": func() interface{} { return new(Xdns) }, "xicmp": func() interface{} { return new(Xicmp) }, }, "type", "settings") ) +type TCPItem struct { + Delay Int32Range `json:"delay"` + Rand int32 `json:"rand"` + Type string `json:"type"` + Packet json.RawMessage `json:"packet"` +} + +type HeaderCustomTCP struct { + Clients [][]TCPItem `json:"clients"` + Servers [][]TCPItem `json:"servers"` + Errors [][]TCPItem `json:"errors"` +} + +func (c *HeaderCustomTCP) Build() (proto.Message, error) { + for _, value := range c.Clients { + for _, item := range value { + if len(item.Packet) > 0 && item.Rand > 0 { + return nil, errors.New("len(item.Packet) > 0 && item.Rand > 0") + } + } + } + for _, value := range c.Servers { + for _, item := range value { + if len(item.Packet) > 0 && item.Rand > 0 { + return nil, errors.New("len(item.Packet) > 0 && item.Rand > 0") + } + } + } + for _, value := range c.Errors { + for _, item := range value { + if len(item.Packet) > 0 && item.Rand > 0 { + return nil, errors.New("len(item.Packet) > 0 && item.Rand > 0") + } + } + } + + clients := make([]*custom.TCPSequence, len(c.Clients)) + for i, value := range c.Clients { + clients[i] = &custom.TCPSequence{} + for _, item := range value { + var err error + if item.Packet, err = PraseByteSlice(item.Packet, item.Type); err != nil { + return nil, err + } + clients[i].Sequence = append(clients[i].Sequence, &custom.TCPItem{ + DelayMin: int64(item.Delay.From), + DelayMax: int64(item.Delay.To), + Rand: item.Rand, + Packet: item.Packet, + }) + } + } + + servers := make([]*custom.TCPSequence, len(c.Servers)) + for i, value := range c.Servers { + servers[i] = &custom.TCPSequence{} + for _, item := range value { + var err error + if item.Packet, err = PraseByteSlice(item.Packet, item.Type); err != nil { + return nil, err + } + servers[i].Sequence = append(servers[i].Sequence, &custom.TCPItem{ + DelayMin: int64(item.Delay.From), + DelayMax: int64(item.Delay.To), + Rand: item.Rand, + Packet: item.Packet, + }) + } + } + + errors := make([]*custom.TCPSequence, len(c.Errors)) + for i, value := range c.Errors { + errors[i] = &custom.TCPSequence{} + for _, item := range value { + var err error + if item.Packet, err = PraseByteSlice(item.Packet, item.Type); err != nil { + return nil, err + } + errors[i].Sequence = append(errors[i].Sequence, &custom.TCPItem{ + DelayMin: int64(item.Delay.From), + DelayMax: int64(item.Delay.To), + Rand: item.Rand, + Packet: item.Packet, + }) + } + } + + return &custom.TCPConfig{ + Clients: clients, + Servers: servers, + Errors: errors, + }, nil +} + +type FragmentMask struct { + Packets string `json:"packets"` + Length Int32Range `json:"length"` + Delay Int32Range `json:"delay"` + MaxSplit Int32Range `json:"maxSplit"` +} + +func (c *FragmentMask) Build() (proto.Message, error) { + config := &fragment.Config{} + + switch strings.ToLower(c.Packets) { + case "tlshello": + config.PacketsFrom = 0 + config.PacketsTo = 1 + case "": + config.PacketsFrom = 0 + config.PacketsTo = 0 + default: + from, to, err := ParseRangeString(c.Packets) + if err != nil { + return nil, errors.New("Invalid PacketsFrom").Base(err) + } + config.PacketsFrom = int64(from) + config.PacketsTo = int64(to) + if config.PacketsFrom == 0 { + return nil, errors.New("PacketsFrom can't be 0") + } + } + + config.LengthMin = int64(c.Length.From) + config.LengthMax = int64(c.Length.To) + if config.LengthMin == 0 { + return nil, errors.New("LengthMin can't be 0") + } + + config.DelayMin = int64(c.Delay.From) + config.DelayMax = int64(c.Delay.To) + + config.MaxSplitMin = int64(c.MaxSplit.From) + config.MaxSplitMax = int64(c.MaxSplit.To) + + return config, nil +} + +type NoiseItem struct { + Rand Int32Range `json:"rand"` + Type string `json:"type"` + Packet json.RawMessage `json:"packet"` + Delay Int32Range `json:"delay"` +} + +type NoiseMask struct { + Reset Int32Range `json:"reset"` + Noise []NoiseItem `json:"noise"` +} + +func (c *NoiseMask) Build() (proto.Message, error) { + for _, item := range c.Noise { + if len(item.Packet) > 0 && item.Rand.To > 0 { + return nil, errors.New("len(item.Packet) > 0 && item.Rand.To > 0") + } + } + + noiseSlice := make([]*noise.Item, 0, len(c.Noise)) + for _, item := range c.Noise { + var err error + if item.Packet, err = PraseByteSlice(item.Packet, item.Type); err != nil { + return nil, err + } + noiseSlice = append(noiseSlice, &noise.Item{ + RandMin: int64(item.Rand.From), + RandMax: int64(item.Rand.To), + Packet: item.Packet, + DelayMin: int64(item.Delay.From), + DelayMax: int64(item.Delay.To), + }) + } + + return &noise.Config{ + ResetMin: int64(c.Reset.From), + ResetMax: int64(c.Reset.To), + Items: noiseSlice, + }, nil +} + +type UDPItem struct { + Rand int32 `json:"rand"` + Type string `json:"type"` + Packet json.RawMessage `json:"packet"` +} + +type HeaderCustomUDP struct { + Client []UDPItem `json:"client"` + Server []UDPItem `json:"server"` +} + +func (c *HeaderCustomUDP) Build() (proto.Message, error) { + for _, item := range c.Client { + if len(item.Packet) > 0 && item.Rand > 0 { + return nil, errors.New("len(item.Packet) > 0 && item.Rand > 0") + } + } + for _, item := range c.Server { + if len(item.Packet) > 0 && item.Rand > 0 { + return nil, errors.New("len(item.Packet) > 0 && item.Rand > 0") + } + } + + client := make([]*custom.UDPItem, 0, len(c.Client)) + for _, item := range c.Client { + var err error + if item.Packet, err = PraseByteSlice(item.Packet, item.Type); err != nil { + return nil, err + } + client = append(client, &custom.UDPItem{ + Rand: item.Rand, + Packet: item.Packet, + }) + } + + server := make([]*custom.UDPItem, 0, len(c.Server)) + for _, item := range c.Server { + var err error + if item.Packet, err = PraseByteSlice(item.Packet, item.Type); err != nil { + return nil, err + } + server = append(server, &custom.UDPItem{ + Rand: item.Rand, + Packet: item.Packet, + }) + } + + return &custom.UDPConfig{ + Client: client, + Server: server, + }, nil +} + type Dns struct { Domain string `json:"domain"` } @@ -1403,7 +1679,7 @@ type Mask struct { func (c *Mask) Build(tcp bool) (proto.Message, error) { loader := udpmaskLoader if tcp { - return nil, errors.New("") + loader = tcpmaskLoader } settings := []byte("{}") diff --git a/proxy/proxy.go b/proxy/proxy.go index 29548d9fb120..acda52d9c14e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -28,6 +28,7 @@ import ( "github.com/xtls/xray-core/proxy/vless/encryption" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/finalmask" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" @@ -687,6 +688,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { conn = realityUConn.NetConn() } } + + conn = finalmask.UnwrapTcpMask(conn) + if pc, ok := conn.(*proxyproto.Conn); ok { conn = pc.Raw() // 8192 > 4096, there is no need to process pc's bufReader @@ -788,6 +792,7 @@ func readV(ctx context.Context, reader buf.Reader, writer buf.Writer, timer sign func IsRAWTransportWithoutSecurity(conn stat.Connection) bool { iConn := stat.TryUnwrapStatsConn(conn) + iConn = finalmask.UnwrapTcpMask(iConn) _, ok1 := iConn.(*proxyproto.Conn) _, ok2 := iConn.(*net.TCPConn) _, ok3 := iConn.(*internet.UnixConnWrapper) diff --git a/transport/internet/finalmask/finalmask.go b/transport/internet/finalmask/finalmask.go index d8a289a780e7..3eee635a5a1c 100644 --- a/transport/internet/finalmask/finalmask.go +++ b/transport/internet/finalmask/finalmask.go @@ -1,18 +1,21 @@ package finalmask import ( + "context" "net" + + "github.com/xtls/xray-core/common/errors" ) -type ConnSize interface { - Size() int32 -} +const ( + UDPSize = 4096 + 123 +) type Udpmask interface { UDP() - WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) - WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) + WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) + WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) } type UdpmaskManager struct { @@ -26,27 +29,23 @@ func NewUdpmaskManager(udpmasks []Udpmask) *UdpmaskManager { } func (m *UdpmaskManager) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) { - leaveSize := int32(0) var err error for i, mask := range m.udpmasks { - raw, err = mask.WrapPacketConnClient(raw, i == len(m.udpmasks)-1, leaveSize, i == 0) + raw, err = mask.WrapPacketConnClient(raw, i, len(m.udpmasks)-1) if err != nil { return nil, err } - leaveSize += raw.(ConnSize).Size() } return raw, nil } func (m *UdpmaskManager) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) { - leaveSize := int32(0) var err error for i, mask := range m.udpmasks { - raw, err = mask.WrapPacketConnServer(raw, i == len(m.udpmasks)-1, leaveSize, i == 0) + raw, err = mask.WrapPacketConnServer(raw, i, len(m.udpmasks)-1) if err != nil { return nil, err } - leaveSize += raw.(ConnSize).Size() } return raw, nil } @@ -89,3 +88,54 @@ func (m *TcpmaskManager) WrapConnServer(raw net.Conn) (net.Conn, error) { } return raw, nil } + +func (m *TcpmaskManager) WrapListener(l net.Listener) (net.Listener, error) { + return NewTcpListener(m, l) +} + +type tcpListener struct { + m *TcpmaskManager + net.Listener +} + +func NewTcpListener(m *TcpmaskManager, l net.Listener) (net.Listener, error) { + return &tcpListener{ + m: m, + Listener: l, + }, nil +} + +func (l *tcpListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return conn, err + } + + newConn, err := l.m.WrapConnServer(conn) + if err != nil { + errors.LogDebugInner(context.Background(), err, "mask err") + // conn.Close() + return conn, nil + } + + return newConn, nil +} + +type TcpMaskConn interface { + TcpMaskConn() + RawConn() net.Conn + Splice() bool +} + +func UnwrapTcpMask(conn net.Conn) net.Conn { + for { + if v, ok := conn.(TcpMaskConn); ok { + if !v.Splice() { + return conn + } + conn = v.RawConn() + } else { + return conn + } + } +} diff --git a/transport/internet/finalmask/fragment/config.go b/transport/internet/finalmask/fragment/config.go new file mode 100644 index 000000000000..165f11ca9d25 --- /dev/null +++ b/transport/internet/finalmask/fragment/config.go @@ -0,0 +1,14 @@ +package fragment + +import "net" + +func (c *Config) TCP() { +} + +func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) { + return NewConnClient(c, raw, false) +} + +func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) { + return NewConnServer(c, raw, true) +} diff --git a/transport/internet/finalmask/fragment/config.pb.go b/transport/internet/finalmask/fragment/config.pb.go new file mode 100644 index 000000000000..c8660f5ec34c --- /dev/null +++ b/transport/internet/finalmask/fragment/config.pb.go @@ -0,0 +1,189 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: transport/internet/finalmask/fragment/config.proto + +package fragment + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + PacketsFrom int64 `protobuf:"varint,1,opt,name=packets_from,json=packetsFrom,proto3" json:"packets_from,omitempty"` + PacketsTo int64 `protobuf:"varint,2,opt,name=packets_to,json=packetsTo,proto3" json:"packets_to,omitempty"` + LengthMin int64 `protobuf:"varint,3,opt,name=length_min,json=lengthMin,proto3" json:"length_min,omitempty"` + LengthMax int64 `protobuf:"varint,4,opt,name=length_max,json=lengthMax,proto3" json:"length_max,omitempty"` + DelayMin int64 `protobuf:"varint,5,opt,name=delay_min,json=delayMin,proto3" json:"delay_min,omitempty"` + DelayMax int64 `protobuf:"varint,6,opt,name=delay_max,json=delayMax,proto3" json:"delay_max,omitempty"` + MaxSplitMin int64 `protobuf:"varint,7,opt,name=max_split_min,json=maxSplitMin,proto3" json:"max_split_min,omitempty"` + MaxSplitMax int64 `protobuf:"varint,8,opt,name=max_split_max,json=maxSplitMax,proto3" json:"max_split_max,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_fragment_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_fragment_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_fragment_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetPacketsFrom() int64 { + if x != nil { + return x.PacketsFrom + } + return 0 +} + +func (x *Config) GetPacketsTo() int64 { + if x != nil { + return x.PacketsTo + } + return 0 +} + +func (x *Config) GetLengthMin() int64 { + if x != nil { + return x.LengthMin + } + return 0 +} + +func (x *Config) GetLengthMax() int64 { + if x != nil { + return x.LengthMax + } + return 0 +} + +func (x *Config) GetDelayMin() int64 { + if x != nil { + return x.DelayMin + } + return 0 +} + +func (x *Config) GetDelayMax() int64 { + if x != nil { + return x.DelayMax + } + return 0 +} + +func (x *Config) GetMaxSplitMin() int64 { + if x != nil { + return x.MaxSplitMin + } + return 0 +} + +func (x *Config) GetMaxSplitMax() int64 { + if x != nil { + return x.MaxSplitMax + } + return 0 +} + +var File_transport_internet_finalmask_fragment_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_fragment_config_proto_rawDesc = "" + + "\n" + + "2transport/internet/finalmask/fragment/config.proto\x12*xray.transport.internet.finalmask.fragment\"\x8a\x02\n" + + "\x06Config\x12!\n" + + "\fpackets_from\x18\x01 \x01(\x03R\vpacketsFrom\x12\x1d\n" + + "\n" + + "packets_to\x18\x02 \x01(\x03R\tpacketsTo\x12\x1d\n" + + "\n" + + "length_min\x18\x03 \x01(\x03R\tlengthMin\x12\x1d\n" + + "\n" + + "length_max\x18\x04 \x01(\x03R\tlengthMax\x12\x1b\n" + + "\tdelay_min\x18\x05 \x01(\x03R\bdelayMin\x12\x1b\n" + + "\tdelay_max\x18\x06 \x01(\x03R\bdelayMax\x12\"\n" + + "\rmax_split_min\x18\a \x01(\x03R\vmaxSplitMin\x12\"\n" + + "\rmax_split_max\x18\b \x01(\x03R\vmaxSplitMaxB\xa0\x01\n" + + ".com.xray.transport.internet.finalmask.fragmentP\x01Z?github.com/xtls/xray-core/transport/internet/finalmask/fragment\xaa\x02*Xray.Transport.Internet.Finalmask.Fragmentb\x06proto3" + +var ( + file_transport_internet_finalmask_fragment_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_fragment_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_fragment_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_fragment_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_fragment_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_fragment_config_proto_rawDesc), len(file_transport_internet_finalmask_fragment_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_fragment_config_proto_rawDescData +} + +var file_transport_internet_finalmask_fragment_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_fragment_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.fragment.Config +} +var file_transport_internet_finalmask_fragment_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_fragment_config_proto_init() } +func file_transport_internet_finalmask_fragment_config_proto_init() { + if File_transport_internet_finalmask_fragment_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_fragment_config_proto_rawDesc), len(file_transport_internet_finalmask_fragment_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_fragment_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_fragment_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_fragment_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_fragment_config_proto = out.File + file_transport_internet_finalmask_fragment_config_proto_goTypes = nil + file_transport_internet_finalmask_fragment_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/fragment/config.proto b/transport/internet/finalmask/fragment/config.proto new file mode 100644 index 000000000000..62aaf18640a9 --- /dev/null +++ b/transport/internet/finalmask/fragment/config.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.fragment; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Fragment"; +option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/fragment"; +option java_package = "com.xray.transport.internet.finalmask.fragment"; +option java_multiple_files = true; + +message Config { + int64 packets_from = 1; + int64 packets_to = 2; + int64 length_min = 3; + int64 length_max = 4; + int64 delay_min = 5; + int64 delay_max = 6; + int64 max_split_min = 7; + int64 max_split_max = 8; +} \ No newline at end of file diff --git a/transport/internet/finalmask/fragment/conn.go b/transport/internet/finalmask/fragment/conn.go new file mode 100644 index 000000000000..1cadf5f0aef0 --- /dev/null +++ b/transport/internet/finalmask/fragment/conn.go @@ -0,0 +1,128 @@ +package fragment + +import ( + "net" + "time" + + "github.com/xtls/xray-core/common/crypto" +) + +type fragmentConn struct { + net.Conn + config *Config + count uint64 + + server bool +} + +func NewConnClient(c *Config, raw net.Conn, server bool) (net.Conn, error) { + conn := &fragmentConn{ + Conn: raw, + config: c, + + server: server, + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.Conn, server bool) (net.Conn, error) { + return NewConnClient(c, raw, server) +} + +func (c *fragmentConn) TcpMaskConn() {} + +func (c *fragmentConn) RawConn() net.Conn { + if c.server { + return c + } + return c.Conn +} + +func (c *fragmentConn) Splice() bool { + if c.server { + return false + } + return true +} + +func (c *fragmentConn) Write(p []byte) (n int, err error) { + c.count++ + + if c.config.PacketsFrom == 0 && c.config.PacketsTo == 1 { + if c.count != 1 || len(p) <= 5 || p[0] != 22 { + return c.Conn.Write(p) + } + recordLen := 5 + ((int(p[3]) << 8) | int(p[4])) + if len(p) < recordLen { + return c.Conn.Write(p) + } + data := p[5:recordLen] + buff := make([]byte, 2048) + var hello []byte + maxSplit := crypto.RandBetween(c.config.MaxSplitMin, c.config.MaxSplitMax) + var splitNum int64 + for from := 0; ; { + to := from + int(crypto.RandBetween(c.config.LengthMin, c.config.LengthMax)) + splitNum++ + if to > len(data) || (maxSplit > 0 && splitNum >= maxSplit) { + to = len(data) + } + l := to - from + if 5+l > len(buff) { + buff = make([]byte, 5+l) + } + copy(buff[:3], p) + copy(buff[5:], data[from:to]) + from = to + buff[3] = byte(l >> 8) + buff[4] = byte(l) + if c.config.DelayMax == 0 { + hello = append(hello, buff[:5+l]...) + } else { + _, err := c.Conn.Write(buff[:5+l]) + time.Sleep(time.Duration(crypto.RandBetween(c.config.DelayMin, c.config.DelayMax)) * time.Millisecond) + if err != nil { + return 0, err + } + } + if from == len(data) { + if len(hello) > 0 { + _, err := c.Conn.Write(hello) + if err != nil { + return 0, err + } + } + if len(p) > recordLen { + n, err := c.Conn.Write(p[recordLen:]) + if err != nil { + return recordLen + n, err + } + } + return len(p), nil + } + } + } + + if c.config.PacketsFrom != 0 && (c.count < uint64(c.config.PacketsFrom) || c.count > uint64(c.config.PacketsTo)) { + return c.Conn.Write(p) + } + maxSplit := crypto.RandBetween(c.config.MaxSplitMin, c.config.MaxSplitMax) + var splitNum int64 + for from := 0; ; { + to := from + int(crypto.RandBetween(c.config.LengthMin, c.config.LengthMax)) + splitNum++ + if to > len(p) || (maxSplit > 0 && splitNum >= maxSplit) { + to = len(p) + } + n, err := c.Conn.Write(p[from:to]) + from += n + if err != nil { + return from, err + } + time.Sleep(time.Duration(crypto.RandBetween(c.config.DelayMin, c.config.DelayMax)) * time.Millisecond) + if from >= len(p) { + return from, nil + } + } +} diff --git a/transport/internet/finalmask/header/custom/config.go b/transport/internet/finalmask/header/custom/config.go new file mode 100644 index 000000000000..1d72e336c0cd --- /dev/null +++ b/transport/internet/finalmask/header/custom/config.go @@ -0,0 +1,27 @@ +package custom + +import ( + "net" +) + +func (c *TCPConfig) TCP() { +} + +func (c *TCPConfig) WrapConnClient(raw net.Conn) (net.Conn, error) { + return NewConnClientTCP(c, raw) +} + +func (c *TCPConfig) WrapConnServer(raw net.Conn) (net.Conn, error) { + return NewConnServerTCP(c, raw) +} + +func (c *UDPConfig) UDP() { +} + +func (c *UDPConfig) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClientUDP(c, raw) +} + +func (c *UDPConfig) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServerUDP(c, raw) +} diff --git a/transport/internet/finalmask/header/custom/config.pb.go b/transport/internet/finalmask/header/custom/config.pb.go new file mode 100644 index 000000000000..ff06eb29143d --- /dev/null +++ b/transport/internet/finalmask/header/custom/config.pb.go @@ -0,0 +1,380 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: transport/internet/finalmask/header/custom/config.proto + +package custom + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type TCPItem struct { + state protoimpl.MessageState `protogen:"open.v1"` + DelayMin int64 `protobuf:"varint,1,opt,name=delay_min,json=delayMin,proto3" json:"delay_min,omitempty"` + DelayMax int64 `protobuf:"varint,2,opt,name=delay_max,json=delayMax,proto3" json:"delay_max,omitempty"` + Rand int32 `protobuf:"varint,3,opt,name=rand,proto3" json:"rand,omitempty"` + Packet []byte `protobuf:"bytes,4,opt,name=packet,proto3" json:"packet,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TCPItem) Reset() { + *x = TCPItem{} + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TCPItem) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TCPItem) ProtoMessage() {} + +func (x *TCPItem) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TCPItem.ProtoReflect.Descriptor instead. +func (*TCPItem) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_custom_config_proto_rawDescGZIP(), []int{0} +} + +func (x *TCPItem) GetDelayMin() int64 { + if x != nil { + return x.DelayMin + } + return 0 +} + +func (x *TCPItem) GetDelayMax() int64 { + if x != nil { + return x.DelayMax + } + return 0 +} + +func (x *TCPItem) GetRand() int32 { + if x != nil { + return x.Rand + } + return 0 +} + +func (x *TCPItem) GetPacket() []byte { + if x != nil { + return x.Packet + } + return nil +} + +type TCPSequence struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sequence []*TCPItem `protobuf:"bytes,1,rep,name=sequence,proto3" json:"sequence,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TCPSequence) Reset() { + *x = TCPSequence{} + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TCPSequence) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TCPSequence) ProtoMessage() {} + +func (x *TCPSequence) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TCPSequence.ProtoReflect.Descriptor instead. +func (*TCPSequence) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_custom_config_proto_rawDescGZIP(), []int{1} +} + +func (x *TCPSequence) GetSequence() []*TCPItem { + if x != nil { + return x.Sequence + } + return nil +} + +type TCPConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Clients []*TCPSequence `protobuf:"bytes,1,rep,name=clients,proto3" json:"clients,omitempty"` + Servers []*TCPSequence `protobuf:"bytes,2,rep,name=servers,proto3" json:"servers,omitempty"` + Errors []*TCPSequence `protobuf:"bytes,3,rep,name=errors,proto3" json:"errors,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TCPConfig) Reset() { + *x = TCPConfig{} + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TCPConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TCPConfig) ProtoMessage() {} + +func (x *TCPConfig) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TCPConfig.ProtoReflect.Descriptor instead. +func (*TCPConfig) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_custom_config_proto_rawDescGZIP(), []int{2} +} + +func (x *TCPConfig) GetClients() []*TCPSequence { + if x != nil { + return x.Clients + } + return nil +} + +func (x *TCPConfig) GetServers() []*TCPSequence { + if x != nil { + return x.Servers + } + return nil +} + +func (x *TCPConfig) GetErrors() []*TCPSequence { + if x != nil { + return x.Errors + } + return nil +} + +type UDPItem struct { + state protoimpl.MessageState `protogen:"open.v1"` + Rand int32 `protobuf:"varint,1,opt,name=rand,proto3" json:"rand,omitempty"` + Packet []byte `protobuf:"bytes,2,opt,name=packet,proto3" json:"packet,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UDPItem) Reset() { + *x = UDPItem{} + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UDPItem) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UDPItem) ProtoMessage() {} + +func (x *UDPItem) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UDPItem.ProtoReflect.Descriptor instead. +func (*UDPItem) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_custom_config_proto_rawDescGZIP(), []int{3} +} + +func (x *UDPItem) GetRand() int32 { + if x != nil { + return x.Rand + } + return 0 +} + +func (x *UDPItem) GetPacket() []byte { + if x != nil { + return x.Packet + } + return nil +} + +type UDPConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Client []*UDPItem `protobuf:"bytes,1,rep,name=client,proto3" json:"client,omitempty"` + Server []*UDPItem `protobuf:"bytes,2,rep,name=server,proto3" json:"server,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UDPConfig) Reset() { + *x = UDPConfig{} + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UDPConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UDPConfig) ProtoMessage() {} + +func (x *UDPConfig) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_custom_config_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UDPConfig.ProtoReflect.Descriptor instead. +func (*UDPConfig) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_custom_config_proto_rawDescGZIP(), []int{4} +} + +func (x *UDPConfig) GetClient() []*UDPItem { + if x != nil { + return x.Client + } + return nil +} + +func (x *UDPConfig) GetServer() []*UDPItem { + if x != nil { + return x.Server + } + return nil +} + +var File_transport_internet_finalmask_header_custom_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_custom_config_proto_rawDesc = "" + + "\n" + + "7transport/internet/finalmask/header/custom/config.proto\x12/xray.transport.internet.finalmask.header.custom\"o\n" + + "\aTCPItem\x12\x1b\n" + + "\tdelay_min\x18\x01 \x01(\x03R\bdelayMin\x12\x1b\n" + + "\tdelay_max\x18\x02 \x01(\x03R\bdelayMax\x12\x12\n" + + "\x04rand\x18\x03 \x01(\x05R\x04rand\x12\x16\n" + + "\x06packet\x18\x04 \x01(\fR\x06packet\"c\n" + + "\vTCPSequence\x12T\n" + + "\bsequence\x18\x01 \x03(\v28.xray.transport.internet.finalmask.header.custom.TCPItemR\bsequence\"\x91\x02\n" + + "\tTCPConfig\x12V\n" + + "\aclients\x18\x01 \x03(\v2<.xray.transport.internet.finalmask.header.custom.TCPSequenceR\aclients\x12V\n" + + "\aservers\x18\x02 \x03(\v2<.xray.transport.internet.finalmask.header.custom.TCPSequenceR\aservers\x12T\n" + + "\x06errors\x18\x03 \x03(\v2<.xray.transport.internet.finalmask.header.custom.TCPSequenceR\x06errors\"5\n" + + "\aUDPItem\x12\x12\n" + + "\x04rand\x18\x01 \x01(\x05R\x04rand\x12\x16\n" + + "\x06packet\x18\x02 \x01(\fR\x06packet\"\xaf\x01\n" + + "\tUDPConfig\x12P\n" + + "\x06client\x18\x01 \x03(\v28.xray.transport.internet.finalmask.header.custom.UDPItemR\x06client\x12P\n" + + "\x06server\x18\x02 \x03(\v28.xray.transport.internet.finalmask.header.custom.UDPItemR\x06serverB\xaf\x01\n" + + "3com.xray.transport.internet.finalmask.header.customP\x01ZDgithub.com/xtls/xray-core/transport/internet/finalmask/header/custom\xaa\x02/Xray.Transport.Internet.Finalmask.Header.Customb\x06proto3" + +var ( + file_transport_internet_finalmask_header_custom_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_custom_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_custom_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_custom_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_custom_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_custom_config_proto_rawDesc), len(file_transport_internet_finalmask_header_custom_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_custom_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_custom_config_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_transport_internet_finalmask_header_custom_config_proto_goTypes = []any{ + (*TCPItem)(nil), // 0: xray.transport.internet.finalmask.header.custom.TCPItem + (*TCPSequence)(nil), // 1: xray.transport.internet.finalmask.header.custom.TCPSequence + (*TCPConfig)(nil), // 2: xray.transport.internet.finalmask.header.custom.TCPConfig + (*UDPItem)(nil), // 3: xray.transport.internet.finalmask.header.custom.UDPItem + (*UDPConfig)(nil), // 4: xray.transport.internet.finalmask.header.custom.UDPConfig +} +var file_transport_internet_finalmask_header_custom_config_proto_depIdxs = []int32{ + 0, // 0: xray.transport.internet.finalmask.header.custom.TCPSequence.sequence:type_name -> xray.transport.internet.finalmask.header.custom.TCPItem + 1, // 1: xray.transport.internet.finalmask.header.custom.TCPConfig.clients:type_name -> xray.transport.internet.finalmask.header.custom.TCPSequence + 1, // 2: xray.transport.internet.finalmask.header.custom.TCPConfig.servers:type_name -> xray.transport.internet.finalmask.header.custom.TCPSequence + 1, // 3: xray.transport.internet.finalmask.header.custom.TCPConfig.errors:type_name -> xray.transport.internet.finalmask.header.custom.TCPSequence + 3, // 4: xray.transport.internet.finalmask.header.custom.UDPConfig.client:type_name -> xray.transport.internet.finalmask.header.custom.UDPItem + 3, // 5: xray.transport.internet.finalmask.header.custom.UDPConfig.server:type_name -> xray.transport.internet.finalmask.header.custom.UDPItem + 6, // [6:6] is the sub-list for method output_type + 6, // [6:6] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_custom_config_proto_init() } +func file_transport_internet_finalmask_header_custom_config_proto_init() { + if File_transport_internet_finalmask_header_custom_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_custom_config_proto_rawDesc), len(file_transport_internet_finalmask_header_custom_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 5, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_custom_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_custom_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_custom_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_custom_config_proto = out.File + file_transport_internet_finalmask_header_custom_config_proto_goTypes = nil + file_transport_internet_finalmask_header_custom_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/custom/config.proto b/transport/internet/finalmask/header/custom/config.proto new file mode 100644 index 000000000000..34314dee9e5c --- /dev/null +++ b/transport/internet/finalmask/header/custom/config.proto @@ -0,0 +1,34 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.custom; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Custom"; +option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/header/custom"; +option java_package = "com.xray.transport.internet.finalmask.header.custom"; +option java_multiple_files = true; + +message TCPItem { + int64 delay_min = 1; + int64 delay_max = 2; + int32 rand = 3; + bytes packet = 4; +} + +message TCPSequence { + repeated TCPItem sequence = 1; +} + +message TCPConfig { + repeated TCPSequence clients = 1; + repeated TCPSequence servers = 2; + repeated TCPSequence errors = 3; +} + +message UDPItem { + int32 rand = 1; + bytes packet = 2; +} + +message UDPConfig { + repeated UDPItem client = 1; + repeated UDPItem server = 2; +} \ No newline at end of file diff --git a/transport/internet/finalmask/header/custom/tcp.go b/transport/internet/finalmask/header/custom/tcp.go new file mode 100644 index 000000000000..9f3e893a0eb6 --- /dev/null +++ b/transport/internet/finalmask/header/custom/tcp.go @@ -0,0 +1,248 @@ +package custom + +import ( + "bytes" + "crypto/rand" + "io" + "net" + "sync" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" + "github.com/xtls/xray-core/common/errors" +) + +type tcpCustomClient struct { + clients []*TCPSequence + servers []*TCPSequence +} + +type tcpCustomClientConn struct { + net.Conn + header *tcpCustomClient + + auth bool + wg sync.WaitGroup + once sync.Once +} + +func NewConnClientTCP(c *TCPConfig, raw net.Conn) (net.Conn, error) { + conn := &tcpCustomClientConn{ + Conn: raw, + header: &tcpCustomClient{ + clients: c.Clients, + servers: c.Servers, + }, + } + + conn.wg.Add(1) + + return conn, nil +} + +func (c *tcpCustomClientConn) TcpMaskConn() {} + +func (c *tcpCustomClientConn) RawConn() net.Conn { + // c.wg.Wait() + + return c.Conn +} + +func (c *tcpCustomClientConn) Splice() bool { + return true +} + +func (c *tcpCustomClientConn) Read(p []byte) (n int, err error) { + c.wg.Wait() + + if !c.auth { + return 0, errors.New("header auth failed") + } + + return c.Conn.Read(p) +} + +func (c *tcpCustomClientConn) Write(p []byte) (n int, err error) { + c.once.Do(func() { + i := 0 + j := 0 + for i = range c.header.clients { + if !writeSequence(c.Conn, c.header.clients[i]) { + c.wg.Done() + return + } + + if j < len(c.header.servers) { + if !readSequence(c.Conn, c.header.servers[j]) { + c.wg.Done() + return + } + j++ + } + } + + for j < len(c.header.servers) { + if !readSequence(c.Conn, c.header.servers[j]) { + c.wg.Done() + return + } + j++ + } + + c.auth = true + c.wg.Done() + }) + + c.wg.Wait() + + if !c.auth { + return 0, errors.New("header auth failed") + } + + return c.Conn.Write(p) +} + +type tcpCustomServer struct { + clients []*TCPSequence + servers []*TCPSequence + errors []*TCPSequence +} + +type tcpCustomServerConn struct { + net.Conn + header *tcpCustomServer + + auth bool + wg sync.WaitGroup + once sync.Once +} + +func NewConnServerTCP(c *TCPConfig, raw net.Conn) (net.Conn, error) { + conn := &tcpCustomServerConn{ + Conn: raw, + header: &tcpCustomServer{ + clients: c.Clients, + servers: c.Servers, + errors: c.Errors, + }, + } + + conn.wg.Add(1) + + return conn, nil +} + +func (c *tcpCustomServerConn) TcpMaskConn() {} + +func (c *tcpCustomServerConn) RawConn() net.Conn { + // c.wg.Wait() + + return c.Conn +} + +func (c *tcpCustomServerConn) Splice() bool { + return true +} + +func (c *tcpCustomServerConn) Read(p []byte) (n int, err error) { + c.once.Do(func() { + i := 0 + j := 0 + for i = range c.header.clients { + if !readSequence(c.Conn, c.header.clients[i]) { + if i < len(c.header.errors) { + writeSequence(c.Conn, c.header.errors[i]) + } + c.wg.Done() + return + } + + if j < len(c.header.servers) { + if !writeSequence(c.Conn, c.header.servers[j]) { + c.wg.Done() + return + } + j++ + } + } + + for j < len(c.header.servers) { + if !writeSequence(c.Conn, c.header.servers[j]) { + c.wg.Done() + return + } + j++ + } + + c.auth = true + c.wg.Done() + }) + + c.wg.Wait() + + if !c.auth { + return 0, errors.New("header auth failed") + } + + return c.Conn.Read(p) +} + +func (c *tcpCustomServerConn) Write(p []byte) (n int, err error) { + c.wg.Wait() + + if !c.auth { + return 0, errors.New("header auth failed") + } + + return c.Conn.Write(p) +} + +func readSequence(r io.Reader, sequence *TCPSequence) bool { + for _, item := range sequence.Sequence { + length := max(int(item.Rand), len(item.Packet)) + buf := make([]byte, length) + n, err := io.ReadFull(r, buf) + if err != nil { + return false + } + if item.Rand > 0 && n != length { + return false + } + if len(item.Packet) > 0 && !bytes.Equal(item.Packet, buf[:n]) { + return false + } + } + return true +} + +func writeSequence(w io.Writer, sequence *TCPSequence) bool { + var merged []byte + for _, item := range sequence.Sequence { + if item.DelayMax > 0 { + if len(merged) > 0 { + _, err := w.Write(merged) + if err != nil { + return false + } + merged = nil + } + time.Sleep(time.Duration(crypto.RandBetween(item.DelayMin, item.DelayMax)) * time.Millisecond) + } + if item.Rand > 0 { + buf := make([]byte, item.Rand) + common.Must2(rand.Read(buf)) + merged = append(merged, buf...) + } else { + merged = append(merged, item.Packet...) + } + } + if len(merged) > 0 { + _, err := w.Write(merged) + if err != nil { + return false + } + merged = nil + } + return true +} diff --git a/transport/internet/finalmask/header/custom/udp.go b/transport/internet/finalmask/header/custom/udp.go new file mode 100644 index 000000000000..efecf90b123d --- /dev/null +++ b/transport/internet/finalmask/header/custom/udp.go @@ -0,0 +1,250 @@ +package custom + +import ( + "bytes" + "context" + "crypto/rand" + "net" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" +) + +type udpCustomClient struct { + client []*UDPItem + server []*UDPItem + merged []byte +} + +func (h *udpCustomClient) Serialize(b []byte) { + index := 0 + for _, item := range h.client { + if item.Rand > 0 { + common.Must2(rand.Read(h.merged[index : index+int(item.Rand)])) + index += int(item.Rand) + } else { + index += len(item.Packet) + } + } + copy(b, h.merged) +} + +func (h *udpCustomClient) Match(b []byte) bool { + if len(b) < len(h.merged) { + return false + } + + data := b + match := true + + for _, item := range h.server { + length := max(int(item.Rand), len(item.Packet)) + + if len(item.Packet) > 0 && !bytes.Equal(item.Packet, data[:length]) { + match = false + break + } + + data = data[length:] + } + + return match +} + +type udpCustomClientConn struct { + net.PacketConn + header *udpCustomClient +} + +func NewConnClientUDP(c *UDPConfig, raw net.PacketConn) (net.PacketConn, error) { + conn := &udpCustomClientConn{ + PacketConn: raw, + header: &udpCustomClient{ + client: c.Client, + server: c.Server, + }, + } + + index := 0 + for _, item := range conn.header.client { + if item.Rand > 0 { + conn.header.merged = append(conn.header.merged, make([]byte, item.Rand)...) + index += int(item.Rand) + } else { + conn.header.merged = append(conn.header.merged, item.Packet...) + index += len(item.Packet) + } + } + + return conn, nil +} + +func (c *udpCustomClientConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } + + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { + return n, addr, err + } + + if !c.header.Match(buf[:n]) { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil + } + + if len(p) < n-len(c.header.merged) { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-len(c.header.merged)) + return 0, addr, nil + } + + copy(p, buf[len(c.header.merged):n]) + + return n - len(c.header.merged), addr, nil +} + +func (c *udpCustomClientConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if len(c.header.merged)+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", len(c.header.merged)+len(p), " ", finalmask.UDPSize) + return 0, nil + } + + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:len(c.header.merged)+len(p)] + } + + copy(buf[len(c.header.merged):], p) + c.header.Serialize(buf) + + _, err = c.PacketConn.WriteTo(buf[:len(c.header.merged)+len(p)], addr) + if err != nil { + return 0, err + } + + return len(p), nil +} + +type udpCustomServer struct { + client []*UDPItem + server []*UDPItem + merged []byte +} + +func (h *udpCustomServer) Serialize(b []byte) { + index := 0 + for _, item := range h.server { + if item.Rand > 0 { + common.Must2(rand.Read(h.merged[index : index+int(item.Rand)])) + index += int(item.Rand) + } else { + index += len(item.Packet) + } + } + copy(b, h.merged) +} + +func (h *udpCustomServer) Match(b []byte) bool { + if len(b) < len(h.merged) { + return false + } + + data := b + match := true + + for _, item := range h.client { + length := max(int(item.Rand), len(item.Packet)) + + if len(item.Packet) > 0 && !bytes.Equal(item.Packet, data[:length]) { + match = false + break + } + + data = data[length:] + } + + return match +} + +type udpCustomServerConn struct { + net.PacketConn + header *udpCustomServer +} + +func NewConnServerUDP(c *UDPConfig, raw net.PacketConn) (net.PacketConn, error) { + conn := &udpCustomServerConn{ + PacketConn: raw, + header: &udpCustomServer{ + client: c.Client, + server: c.Server, + }, + } + + index := 0 + for _, item := range conn.header.server { + if item.Rand > 0 { + conn.header.merged = append(conn.header.merged, make([]byte, item.Rand)...) + index += int(item.Rand) + } else { + conn.header.merged = append(conn.header.merged, item.Packet...) + index += len(item.Packet) + } + } + + return conn, nil +} + +func (c *udpCustomServerConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } + + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { + return n, addr, err + } + + if !c.header.Match(buf[:n]) { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil + } + + if len(p) < n-len(c.header.merged) { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-len(c.header.merged)) + return 0, addr, nil + } + + copy(p, buf[len(c.header.merged):n]) + + return n - len(c.header.merged), addr, nil +} + +func (c *udpCustomServerConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if len(c.header.merged)+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", len(c.header.merged)+len(p), " ", finalmask.UDPSize) + return 0, nil + } + + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:len(c.header.merged)+len(p)] + } + + copy(buf[len(c.header.merged):], p) + c.header.Serialize(buf) + + _, err = c.PacketConn.WriteTo(buf[:len(c.header.merged)+len(p)], addr) + if err != nil { + return 0, err + } + + return len(p), nil +} diff --git a/transport/internet/finalmask/header/dns/config.go b/transport/internet/finalmask/header/dns/config.go index d5aa5cc399e0..7be9eb9aae74 100644 --- a/transport/internet/finalmask/header/dns/config.go +++ b/transport/internet/finalmask/header/dns/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/header/dns/conn.go b/transport/internet/finalmask/header/dns/conn.go index 0baa8fec9972..a60aa2e4f571 100644 --- a/transport/internet/finalmask/header/dns/conn.go +++ b/transport/internet/finalmask/header/dns/conn.go @@ -1,14 +1,13 @@ package dns import ( + "context" "encoding/binary" - "io" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) func packDomainName(s string, msg []byte) (off1 int, err error) { @@ -81,8 +80,8 @@ type dns struct { header []byte } -func (h *dns) Size() int32 { - return int32(len(h.header)) +func (h *dns) Size() int { + return len(h.header) } func (h *dns) Serialize(b []byte) { @@ -91,19 +90,11 @@ func (h *dns) Serialize(b []byte) { } type dnsConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn header *dns - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { var header []byte header = binary.BigEndian.AppendUint16(header, 0x0000) // Transaction ID header = binary.BigEndian.AppendUint16(header, 0x0100) // Flags: Standard query @@ -121,121 +112,65 @@ func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) ( header = binary.BigEndian.AppendUint16(header, 0x0001) // Class: IN conn := &dnsConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, + PacketConn: raw, header: &dns{ header: header, }, } - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) - } - return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *dnsConn) Size() int32 { - return c.header.Size() +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *dnsConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - copy(p, c.readBuf[c.Size():n]) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + if n < c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil } - copy(p, p[c.Size():n]) - - return n - int(c.Size()), addr, err -} - -func (c *dnsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-c.header.Size()) + return 0, addr, nil } - c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + copy(p, buf[c.header.Size():n]) - return c.conn.WriteTo(p, addr) + return n - c.header.Size(), addr, nil } -func (c *dnsConn) Close() error { - return c.conn.Close() -} +func (c *dnsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.header.Size()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.header.Size()+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *dnsConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.header.Size()+len(p)] + } -func (c *dnsConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + copy(buf[c.header.Size():], p) + c.header.Serialize(buf) -func (c *dnsConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.header.Size()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *dnsConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/header/dtls/config.go b/transport/internet/finalmask/header/dtls/config.go index ccce33decfdb..02d102d06e59 100644 --- a/transport/internet/finalmask/header/dtls/config.go +++ b/transport/internet/finalmask/header/dtls/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/header/dtls/conn.go b/transport/internet/finalmask/header/dtls/conn.go index 8b12f83e3ea8..0f7c16e8e386 100644 --- a/transport/internet/finalmask/header/dtls/conn.go +++ b/transport/internet/finalmask/header/dtls/conn.go @@ -1,13 +1,12 @@ package dtls import ( - "io" + "context" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type dtls struct { @@ -16,7 +15,7 @@ type dtls struct { sequence uint32 } -func (*dtls) Size() int32 { +func (*dtls) Size() int { return 1 + 2 + 2 + 6 + 2 } @@ -42,24 +41,13 @@ func (h *dtls) Serialize(b []byte) { } type dtlsConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn header *dtls - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { conn := &dtlsConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, + PacketConn: raw, header: &dtls{ epoch: dice.RollUint16(), sequence: 0, @@ -67,112 +55,59 @@ func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) ( }, } - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) - } - return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *dtlsConn) Size() int32 { - return c.header.Size() +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *dtlsConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - copy(p, c.readBuf[c.Size():n]) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + if n < c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil } - copy(p, p[c.Size():n]) - - return n - int(c.Size()), addr, err -} - -func (c *dtlsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-c.header.Size()) + return 0, addr, nil } - c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + copy(p, buf[c.header.Size():n]) - return c.conn.WriteTo(p, addr) + return n - c.header.Size(), addr, nil } -func (c *dtlsConn) Close() error { - return c.conn.Close() -} +func (c *dtlsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.header.Size()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.header.Size()+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *dtlsConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.header.Size()+len(p)] + } -func (c *dtlsConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + copy(buf[c.header.Size():], p) + c.header.Serialize(buf) -func (c *dtlsConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.header.Size()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *dtlsConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/header/srtp/config.go b/transport/internet/finalmask/header/srtp/config.go index 45def61629ca..875bf899f27d 100644 --- a/transport/internet/finalmask/header/srtp/config.go +++ b/transport/internet/finalmask/header/srtp/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/header/srtp/conn.go b/transport/internet/finalmask/header/srtp/conn.go index bb03bce70953..8dabcd5f0485 100644 --- a/transport/internet/finalmask/header/srtp/conn.go +++ b/transport/internet/finalmask/header/srtp/conn.go @@ -1,14 +1,13 @@ package srtp import ( + "context" "encoding/binary" - "io" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type srtp struct { @@ -16,7 +15,7 @@ type srtp struct { number uint16 } -func (*srtp) Size() int32 { +func (*srtp) Size() int { return 4 } @@ -27,136 +26,72 @@ func (h *srtp) Serialize(b []byte) { } type srtpConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn header *srtp - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { conn := &srtpConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, + PacketConn: raw, header: &srtp{ header: 0xB5E8, number: dice.RollUint16(), }, } - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) - } - return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *srtpConn) Size() int32 { - return c.header.Size() +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *srtpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - copy(p, c.readBuf[c.Size():n]) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + if n < c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil } - copy(p, p[c.Size():n]) - - return n - int(c.Size()), addr, err -} - -func (c *srtpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-c.header.Size()) + return 0, addr, nil } - c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + copy(p, buf[c.header.Size():n]) - return c.conn.WriteTo(p, addr) + return n - c.header.Size(), addr, nil } -func (c *srtpConn) Close() error { - return c.conn.Close() -} +func (c *srtpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.header.Size()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.header.Size()+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *srtpConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.header.Size()+len(p)] + } -func (c *srtpConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + copy(buf[c.header.Size():], p) + c.header.Serialize(buf) -func (c *srtpConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.header.Size()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *srtpConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/header/utp/config.go b/transport/internet/finalmask/header/utp/config.go index a579d48366e8..804f462acefc 100644 --- a/transport/internet/finalmask/header/utp/config.go +++ b/transport/internet/finalmask/header/utp/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/header/utp/conn.go b/transport/internet/finalmask/header/utp/conn.go index 5d48193ca74e..7647e0490f3f 100644 --- a/transport/internet/finalmask/header/utp/conn.go +++ b/transport/internet/finalmask/header/utp/conn.go @@ -1,14 +1,13 @@ package utp import ( + "context" "encoding/binary" - "io" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type utp struct { @@ -17,7 +16,7 @@ type utp struct { connectionID uint16 } -func (*utp) Size() int32 { +func (*utp) Size() int { return 4 } @@ -28,24 +27,13 @@ func (h *utp) Serialize(b []byte) { } type utpConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn header *utp - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { conn := &utpConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, + PacketConn: raw, header: &utp{ header: 1, extension: 0, @@ -53,112 +41,59 @@ func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) ( }, } - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) - } - return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *utpConn) Size() int32 { - return c.header.Size() +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *utpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - copy(p, c.readBuf[c.Size():n]) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + if n < c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil } - copy(p, p[c.Size():n]) - - return n - int(c.Size()), addr, err -} - -func (c *utpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-c.header.Size()) + return 0, addr, nil } - c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + copy(p, buf[c.header.Size():n]) - return c.conn.WriteTo(p, addr) + return n - c.header.Size(), addr, nil } -func (c *utpConn) Close() error { - return c.conn.Close() -} +func (c *utpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.header.Size()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.header.Size()+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *utpConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.header.Size()+len(p)] + } -func (c *utpConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + copy(buf[c.header.Size():], p) + c.header.Serialize(buf) -func (c *utpConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.header.Size()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *utpConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/header/wechat/config.go b/transport/internet/finalmask/header/wechat/config.go index 34971ace3a72..8cd6ad3f9b73 100644 --- a/transport/internet/finalmask/header/wechat/config.go +++ b/transport/internet/finalmask/header/wechat/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/header/wechat/conn.go b/transport/internet/finalmask/header/wechat/conn.go index f1f56c49d72d..157c1947112f 100644 --- a/transport/internet/finalmask/header/wechat/conn.go +++ b/transport/internet/finalmask/header/wechat/conn.go @@ -1,21 +1,20 @@ package wechat import ( + "context" "encoding/binary" - "io" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type wechat struct { sn uint32 } -func (*wechat) Size() int32 { +func (*wechat) Size() int { return 13 } @@ -34,135 +33,71 @@ func (h *wechat) Serialize(b []byte) { } type wechatConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn header *wechat - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { conn := &wechatConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, + PacketConn: raw, header: &wechat{ sn: uint32(dice.RollUint16()), }, } - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) - } - return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *wechatConn) Size() int32 { - return c.header.Size() +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *wechatConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - copy(p, c.readBuf[c.Size():n]) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + if n < c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil } - copy(p, p[c.Size():n]) - - return n - int(c.Size()), addr, err -} - -func (c *wechatConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-c.header.Size()) + return 0, addr, nil } - c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + copy(p, buf[c.header.Size():n]) - return c.conn.WriteTo(p, addr) + return n - c.header.Size(), addr, nil } -func (c *wechatConn) Close() error { - return c.conn.Close() -} +func (c *wechatConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.header.Size()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.header.Size()+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *wechatConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.header.Size()+len(p)] + } -func (c *wechatConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + copy(buf[c.header.Size():], p) + c.header.Serialize(buf) -func (c *wechatConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.header.Size()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *wechatConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/header/wireguard/config.go b/transport/internet/finalmask/header/wireguard/config.go index 5eeee34ba6ae..b159f438a994 100644 --- a/transport/internet/finalmask/header/wireguard/config.go +++ b/transport/internet/finalmask/header/wireguard/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/header/wireguard/conn.go b/transport/internet/finalmask/header/wireguard/conn.go index f4bf17748242..2d310d282f36 100644 --- a/transport/internet/finalmask/header/wireguard/conn.go +++ b/transport/internet/finalmask/header/wireguard/conn.go @@ -1,17 +1,16 @@ package wireguard import ( - "io" + "context" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type wireguare struct{} -func (*wireguare) Size() int32 { +func (*wireguare) Size() int { return 4 } @@ -23,133 +22,69 @@ func (h *wireguare) Serialize(b []byte) { } type wireguareConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn header *wireguare - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { conn := &wireguareConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, - header: &wireguare{}, - } - - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) + PacketConn: raw, + header: &wireguare{}, } return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *wireguareConn) Size() int32 { - return c.header.Size() +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *wireguareConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) - } - - copy(p, c.readBuf[c.Size():n]) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + if n < c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err header mismatch") + return 0, addr, nil } - copy(p, p[c.Size():n]) - - return n - int(c.Size()), addr, err -} - -func (c *wireguareConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-c.header.Size() { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-c.header.Size()) + return 0, addr, nil } - c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + copy(p, buf[c.header.Size():n]) - return c.conn.WriteTo(p, addr) + return n - c.header.Size(), addr, nil } -func (c *wireguareConn) Close() error { - return c.conn.Close() -} +func (c *wireguareConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.header.Size()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.header.Size()+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *wireguareConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.header.Size()+len(p)] + } -func (c *wireguareConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + copy(buf[c.header.Size():], p) + c.header.Serialize(buf) -func (c *wireguareConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.header.Size()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *wireguareConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/mkcp/aes128gcm/config.go b/transport/internet/finalmask/mkcp/aes128gcm/config.go index 595dd4ee93cb..da3459c6d1e5 100644 --- a/transport/internet/finalmask/mkcp/aes128gcm/config.go +++ b/transport/internet/finalmask/mkcp/aes128gcm/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/mkcp/aes128gcm/conn.go b/transport/internet/finalmask/mkcp/aes128gcm/conn.go index 9f36fc2adeaf..89fd90317b05 100644 --- a/transport/internet/finalmask/mkcp/aes128gcm/conn.go +++ b/transport/internet/finalmask/mkcp/aes128gcm/conn.go @@ -1,99 +1,77 @@ package aes128gcm import ( + "context" "crypto/cipher" "crypto/rand" "crypto/sha256" - "io" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type aes128gcmConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn aead cipher.AEAD - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { hashedPsk := sha256.Sum256([]byte(c.Password)) conn := &aes128gcmConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, - aead: crypto.NewAesGcm(hashedPsk[:16]), - } - - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) + PacketConn: raw, + aead: crypto.NewAesGcm(hashedPsk[:16]), } return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *aes128gcmConn) Size() int32 { - return int32(c.aead.NonceSize()) + int32(c.aead.Overhead()) +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *aes128gcmConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() + if len(p) < finalmask.UDPSize { + buf := make([]byte, finalmask.UDPSize) - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + if n < c.aead.NonceSize()+c.aead.Overhead() { + errors.LogDebug(context.Background(), addr, " mask read err aead short lenth ", n) + return 0, addr, nil } nonceSize := c.aead.NonceSize() - nonce := c.readBuf[:nonceSize] - ciphertext := c.readBuf[nonceSize:n] - _, err = c.aead.Open(p[:0], nonce, ciphertext, nil) + nonce := buf[:nonceSize] + ciphertext := buf[nonceSize:n] + plaintext, err := c.aead.Open(p[:0], nonce, ciphertext, nil) if err != nil { - c.readMutex.Unlock() - return 0, addr, errors.New("aead open").Base(err) + errors.LogDebug(context.Background(), addr, " mask read err aead open ", err) + return 0, addr, nil } - c.readMutex.Unlock() - return n - int(c.Size()), addr, nil + if len(plaintext) > len(p) { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", len(plaintext)) + return 0, addr, nil + } + + return n - c.aead.NonceSize() - c.aead.Overhead(), addr, nil } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + if n < c.aead.NonceSize()+c.aead.Overhead() { + errors.LogDebug(context.Background(), addr, " mask read err aead short lenth ", n) + return 0, addr, nil } nonceSize := c.aead.NonceSize() @@ -101,74 +79,40 @@ func (c *aes128gcmConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { ciphertext := p[nonceSize:n] _, err = c.aead.Open(ciphertext[:0], nonce, ciphertext, nil) if err != nil { - return 0, addr, errors.New("aead open").Base(err) + errors.LogDebug(context.Background(), addr, " mask read err aead open ", err) + return 0, addr, nil } + copy(p, p[nonceSize:n-c.aead.Overhead()]) - return n - int(c.Size()), addr, nil + return n - c.aead.NonceSize() - c.aead.Overhead(), addr, nil } func (c *aes128gcmConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+int32(c.aead.NonceSize()):], p) - // n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - nonceSize := c.aead.NonceSize() - nonce := c.writeBuf[c.leaveSize : c.leaveSize+int32(nonceSize)] - common.Must2(rand.Read(nonce)) - // copy(c.writeBuf[c.leaveSize+int32(nonceSize):], c.writeBuf[c.leaveSize+c.Size():n]) - plaintext := c.writeBuf[c.leaveSize+int32(nonceSize) : n-c.aead.Overhead()] - _ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } + if c.aead.NonceSize()+c.aead.Overhead()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.aead.NonceSize()+c.aead.Overhead()+len(p), " ", finalmask.UDPSize) + return 0, nil + } - c.writeMutex.Unlock() - return len(p), nil + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.aead.NonceSize()+c.aead.Overhead()+len(p)] + copy(buf[c.aead.NonceSize():], p) + p = buf[c.aead.NonceSize() : c.aead.NonceSize()+len(p)] } nonceSize := c.aead.NonceSize() - nonce := p[c.leaveSize : c.leaveSize+int32(nonceSize)] + nonce := buf[:nonceSize] common.Must2(rand.Read(nonce)) - copy(p[c.leaveSize+int32(nonceSize):], p[c.leaveSize+c.Size():]) - plaintext := p[c.leaveSize+int32(nonceSize) : len(p)-c.aead.Overhead()] - _ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil) - - return c.conn.WriteTo(p, addr) -} + ciphertext := buf[nonceSize : c.aead.NonceSize()+c.aead.Overhead()+len(p)] + _ = c.aead.Seal(ciphertext[:0], nonce, p, nil) -func (c *aes128gcmConn) Close() error { - return c.conn.Close() -} - -func (c *aes128gcmConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *aes128gcmConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *aes128gcmConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.aead.NonceSize()+c.aead.Overhead()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *aes128gcmConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/mkcp/original/config.go b/transport/internet/finalmask/mkcp/original/config.go index 026c979d625a..19db41383508 100644 --- a/transport/internet/finalmask/mkcp/original/config.go +++ b/transport/internet/finalmask/mkcp/original/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/mkcp/original/conn.go b/transport/internet/finalmask/mkcp/original/conn.go index c97c8d7c8712..898541bdc649 100644 --- a/transport/internet/finalmask/mkcp/original/conn.go +++ b/transport/internet/finalmask/mkcp/original/conn.go @@ -1,16 +1,15 @@ package original import ( + "context" "crypto/cipher" "encoding/binary" "hash/fnv" - "io" "net" - sync "sync" - "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type simple struct{} @@ -75,151 +74,77 @@ func (a *simple) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) { } type simpleConn struct { - first bool - leaveSize int32 - - conn net.PacketConn + net.PacketConn aead cipher.AEAD - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { conn := &simpleConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, - aead: &simple{}, - } - - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) + PacketConn: raw, + aead: &simple{}, } return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *simpleConn) Size() int32 { - return int32(c.aead.NonceSize()) + int32(c.aead.Overhead()) +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } func (c *simpleConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) - } - - ciphertext := c.readBuf[:n] - opened, err := c.aead.Open(nil, nil, ciphertext, nil) - if err != nil { - c.readMutex.Unlock() - return 0, addr, errors.New("aead open").Base(err) - } - - copy(p, opened) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, nil + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + if n < c.aead.Overhead() { + errors.LogDebug(context.Background(), addr, " mask read err aead short lenth ", n) + return 0, addr, nil } - ciphertext := p[:n] + ciphertext := buf[:n] opened, err := c.aead.Open(nil, nil, ciphertext, nil) if err != nil { - c.readMutex.Unlock() - return 0, addr, errors.New("aead open").Base(err) + errors.LogDebug(context.Background(), addr, " mask read err aead open ", err) + return 0, addr, nil + } + + if len(opened) > len(p) { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", len(opened)) + return 0, addr, nil } copy(p, opened) - return n - int(c.Size()), addr, nil + return n - c.aead.Overhead(), addr, nil } func (c *simpleConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - plaintext := c.writeBuf[c.leaveSize+c.Size() : n] - sealed := c.aead.Seal(nil, nil, plaintext, nil) - copy(c.writeBuf[c.leaveSize:], sealed) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if c.aead.Overhead()+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", c.aead.Overhead()+len(p), " ", finalmask.UDPSize) + return 0, nil } - plaintext := p[c.leaveSize+c.Size():] - sealed := c.aead.Seal(nil, nil, plaintext, nil) - copy(p[c.leaveSize:], sealed) - - return c.conn.WriteTo(p, addr) -} - -func (c *simpleConn) Close() error { - return c.conn.Close() -} - -func (c *simpleConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:c.aead.Overhead()+len(p)] + copy(buf[c.aead.Overhead():], p) + p = buf[c.aead.Overhead() : c.aead.Overhead()+len(p)] + } -func (c *simpleConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + _ = c.aead.Seal(buf[:0], nil, p, nil) -func (c *simpleConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:c.aead.Overhead()+len(p)], addr) + if err != nil { + return 0, err + } -func (c *simpleConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/mkcp/original/simple_test.go b/transport/internet/finalmask/mkcp/original/simple_test.go index f7db54748df8..be9d68398f37 100644 --- a/transport/internet/finalmask/mkcp/original/simple_test.go +++ b/transport/internet/finalmask/mkcp/original/simple_test.go @@ -8,6 +8,22 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" ) +func TestSimpleSealInPlace(t *testing.T) { + aead := original.NewSimple() + + text := []byte("0123456789012") + buf := make([]byte, 8192) + + copy(buf[aead.Overhead():], text) + plaintext := buf[aead.Overhead() : aead.Overhead()+len(text)] + + sealed := aead.Seal(nil, nil, plaintext, nil) + + _ = aead.Seal(buf[:0], nil, plaintext, nil) + + assert.Equal(t, sealed, buf[:aead.Overhead()+len(text)]) +} + func TestOriginalBounce(t *testing.T) { aead := original.NewSimple() buf := make([]byte, aead.NonceSize()+aead.Overhead()) diff --git a/transport/internet/finalmask/noise/config.go b/transport/internet/finalmask/noise/config.go new file mode 100644 index 000000000000..1764c0b19b19 --- /dev/null +++ b/transport/internet/finalmask/noise/config.go @@ -0,0 +1,14 @@ +package noise + +import "net" + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) +} diff --git a/transport/internet/finalmask/noise/config.pb.go b/transport/internet/finalmask/noise/config.pb.go new file mode 100644 index 000000000000..16e1b25328fd --- /dev/null +++ b/transport/internet/finalmask/noise/config.pb.go @@ -0,0 +1,225 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: transport/internet/finalmask/noise/config.proto + +package noise + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Item struct { + state protoimpl.MessageState `protogen:"open.v1"` + RandMin int64 `protobuf:"varint,1,opt,name=rand_min,json=randMin,proto3" json:"rand_min,omitempty"` + RandMax int64 `protobuf:"varint,2,opt,name=rand_max,json=randMax,proto3" json:"rand_max,omitempty"` + Packet []byte `protobuf:"bytes,3,opt,name=packet,proto3" json:"packet,omitempty"` + DelayMin int64 `protobuf:"varint,4,opt,name=delay_min,json=delayMin,proto3" json:"delay_min,omitempty"` + DelayMax int64 `protobuf:"varint,5,opt,name=delay_max,json=delayMax,proto3" json:"delay_max,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Item) Reset() { + *x = Item{} + mi := &file_transport_internet_finalmask_noise_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Item) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Item) ProtoMessage() {} + +func (x *Item) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_noise_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Item.ProtoReflect.Descriptor instead. +func (*Item) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_noise_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Item) GetRandMin() int64 { + if x != nil { + return x.RandMin + } + return 0 +} + +func (x *Item) GetRandMax() int64 { + if x != nil { + return x.RandMax + } + return 0 +} + +func (x *Item) GetPacket() []byte { + if x != nil { + return x.Packet + } + return nil +} + +func (x *Item) GetDelayMin() int64 { + if x != nil { + return x.DelayMin + } + return 0 +} + +func (x *Item) GetDelayMax() int64 { + if x != nil { + return x.DelayMax + } + return 0 +} + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + ResetMin int64 `protobuf:"varint,1,opt,name=reset_min,json=resetMin,proto3" json:"reset_min,omitempty"` + ResetMax int64 `protobuf:"varint,2,opt,name=reset_max,json=resetMax,proto3" json:"reset_max,omitempty"` + Items []*Item `protobuf:"bytes,3,rep,name=items,proto3" json:"items,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_noise_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_noise_config_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_noise_config_proto_rawDescGZIP(), []int{1} +} + +func (x *Config) GetResetMin() int64 { + if x != nil { + return x.ResetMin + } + return 0 +} + +func (x *Config) GetResetMax() int64 { + if x != nil { + return x.ResetMax + } + return 0 +} + +func (x *Config) GetItems() []*Item { + if x != nil { + return x.Items + } + return nil +} + +var File_transport_internet_finalmask_noise_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_noise_config_proto_rawDesc = "" + + "\n" + + "/transport/internet/finalmask/noise/config.proto\x12'xray.transport.internet.finalmask.noise\"\x8e\x01\n" + + "\x04Item\x12\x19\n" + + "\brand_min\x18\x01 \x01(\x03R\arandMin\x12\x19\n" + + "\brand_max\x18\x02 \x01(\x03R\arandMax\x12\x16\n" + + "\x06packet\x18\x03 \x01(\fR\x06packet\x12\x1b\n" + + "\tdelay_min\x18\x04 \x01(\x03R\bdelayMin\x12\x1b\n" + + "\tdelay_max\x18\x05 \x01(\x03R\bdelayMax\"\x87\x01\n" + + "\x06Config\x12\x1b\n" + + "\treset_min\x18\x01 \x01(\x03R\bresetMin\x12\x1b\n" + + "\treset_max\x18\x02 \x01(\x03R\bresetMax\x12C\n" + + "\x05items\x18\x03 \x03(\v2-.xray.transport.internet.finalmask.noise.ItemR\x05itemsB\x97\x01\n" + + "+com.xray.transport.internet.finalmask.noiseP\x01Z xray.transport.internet.finalmask.noise.Item + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_noise_config_proto_init() } +func file_transport_internet_finalmask_noise_config_proto_init() { + if File_transport_internet_finalmask_noise_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_noise_config_proto_rawDesc), len(file_transport_internet_finalmask_noise_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_noise_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_noise_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_noise_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_noise_config_proto = out.File + file_transport_internet_finalmask_noise_config_proto_goTypes = nil + file_transport_internet_finalmask_noise_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/noise/config.proto b/transport/internet/finalmask/noise/config.proto new file mode 100644 index 000000000000..552603faf2ff --- /dev/null +++ b/transport/internet/finalmask/noise/config.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.noise; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Noise"; +option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/noise"; +option java_package = "com.xray.transport.internet.finalmask.noise"; +option java_multiple_files = true; + +message Item { + int64 rand_min = 1; + int64 rand_max = 2; + bytes packet = 3; + int64 delay_min = 4; + int64 delay_max = 5; +} + +message Config { + int64 reset_min = 1; + int64 reset_max = 2; + repeated Item items = 3; +} diff --git a/transport/internet/finalmask/noise/conn.go b/transport/internet/finalmask/noise/conn.go new file mode 100644 index 000000000000..022fb21bcb62 --- /dev/null +++ b/transport/internet/finalmask/noise/conn.go @@ -0,0 +1,98 @@ +package noise + +import ( + "crypto/rand" + "net" + "sync" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" +) + +type noiseConn struct { + net.PacketConn + config *Config + m map[string]time.Time + stop chan struct{} + once sync.Once + mutex sync.RWMutex +} + +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { + conn := &noiseConn{ + PacketConn: raw, + config: c, + m: make(map[string]time.Time), + stop: make(chan struct{}), + } + + if conn.config.ResetMax > 0 { + go conn.reset() + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) +} + +func (c *noiseConn) reset() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.mutex.RLock() + now := time.Now() + timeOut := make([]string, 0, len(c.m)) + for key, last := range c.m { + if now.After(last) { + timeOut = append(timeOut, key) + } + } + c.mutex.RUnlock() + + for _, key := range timeOut { + c.mutex.Lock() + delete(c.m, key) + c.mutex.Unlock() + } + case <-c.stop: + return + } + } +} + +func (c *noiseConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.mutex.RLock() + _, ready := c.m[addr.String()] + c.mutex.RUnlock() + + if !ready { + c.mutex.Lock() + _, ready = c.m[addr.String()] + if !ready { + for _, item := range c.config.Items { + if item.RandMax > 0 { + item.Packet = make([]byte, crypto.RandBetween(item.RandMin, item.RandMax)) + common.Must2(rand.Read(item.Packet)) + } + c.PacketConn.WriteTo(item.Packet, addr) + time.Sleep(time.Duration(crypto.RandBetween(item.DelayMin, item.DelayMax)) * time.Millisecond) + } + c.m[addr.String()] = time.Now().Add(time.Duration(crypto.RandBetween(c.config.ResetMin, c.config.ResetMax)) * time.Second) + } + c.mutex.Unlock() + } + + return c.PacketConn.WriteTo(p, addr) +} + +func (c *noiseConn) Close() error { + c.once.Do(func() { + close(c.stop) + }) + return c.PacketConn.Close() +} diff --git a/transport/internet/finalmask/salamander/config.go b/transport/internet/finalmask/salamander/config.go index c864e270d666..371b528c61c1 100644 --- a/transport/internet/finalmask/salamander/config.go +++ b/transport/internet/finalmask/salamander/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, first, leaveSize) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/salamander/conn.go b/transport/internet/finalmask/salamander/conn.go index 154a6aa33312..f693ef07ad62 100644 --- a/transport/internet/finalmask/salamander/conn.go +++ b/transport/internet/finalmask/salamander/conn.go @@ -1,147 +1,83 @@ package salamander import ( - "io" + "context" "net" - "sync" - "time" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) -type obfsPacketConn struct { - first bool - leaveSize int32 - - conn net.PacketConn +type salamanderConn struct { + net.PacketConn obfs *SalamanderObfuscator - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { ob, err := NewSalamanderObfuscator([]byte(c.Password)) if err != nil { return nil, errors.New("salamander err").Base(err) } - conn := &obfsPacketConn{ - first: first, - leaveSize: leaveSize, - - conn: raw, - obfs: ob, - } - - if first { - conn.readBuf = make([]byte, 8192) - conn.writeBuf = make([]byte, 8192) + conn := &salamanderConn{ + PacketConn: raw, + obfs: ob, } return conn, nil } -func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { - return NewConnClient(c, raw, first, leaveSize) -} - -func (c *obfsPacketConn) Size() int32 { - return smSaltLen +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if c.first { - c.readMutex.Lock() - - n, addr, err = c.conn.ReadFrom(c.readBuf) - if err != nil { - c.readMutex.Unlock() - return n, addr, err - } - - if n < int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer) - } - - if len(p) < n-int(c.Size()) { - c.readMutex.Unlock() - return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer) - } - - c.obfs.Deobfuscate(c.readBuf[:n], p) - - c.readMutex.Unlock() - return n - int(c.Size()), addr, err +func (c *salamanderConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buf := p + if len(p) < finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) } - n, addr, err = c.conn.ReadFrom(p) - if err != nil { + n, addr, err = c.PacketConn.ReadFrom(buf) + if err != nil || n == 0 { return n, addr, err } - if n < int(c.Size()) { - return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer) + if n < smSaltLen { + errors.LogDebug(context.Background(), addr, " mask read err short lenth ", n) + return 0, addr, nil } - c.obfs.Deobfuscate(p[:n], p) - - return n - int(c.Size()), addr, err -} - -func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.first { - if c.leaveSize+c.Size()+int32(len(p)) > 8192 { - return 0, errors.New("too many masks") - } - - c.writeMutex.Lock() - - n = copy(c.writeBuf[c.leaveSize+c.Size():], p) - n += int(c.leaveSize) + int(c.Size()) - - c.obfs.Obfuscate(c.writeBuf[c.leaveSize+c.Size():n], c.writeBuf[c.leaveSize:n]) - - nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) - - if err != nil { - c.writeMutex.Unlock() - return 0, err - } - - if nn != n { - c.writeMutex.Unlock() - return 0, errors.New("nn != n") - } - - c.writeMutex.Unlock() - return len(p), nil + if len(p) < n-smSaltLen { + errors.LogDebug(context.Background(), addr, " mask read err short buffer ", len(p), " ", n-smSaltLen) + return 0, addr, nil } - c.obfs.Obfuscate(p[c.leaveSize+c.Size():], p[c.leaveSize:]) + c.obfs.Deobfuscate(buf[:n], p) - return c.conn.WriteTo(p, addr) + return n - smSaltLen, addr, nil } -func (c *obfsPacketConn) Close() error { - return c.conn.Close() -} +func (c *salamanderConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if smSaltLen+len(p) > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", smSaltLen+len(p), " ", finalmask.UDPSize) + return 0, nil + } -func (c *obfsPacketConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} + var buf []byte + if cap(p) != finalmask.UDPSize { + buf = make([]byte, finalmask.UDPSize) + } else { + buf = p[:smSaltLen+len(p)] + copy(buf[smSaltLen:], p) + p = buf[smSaltLen:] + } -func (c *obfsPacketConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} + c.obfs.Obfuscate(p, buf) -func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} + _, err = c.PacketConn.WriteTo(buf[:smSaltLen+len(p)], addr) + if err != nil { + return 0, err + } -func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return len(p), nil } diff --git a/transport/internet/finalmask/tcp_test.go b/transport/internet/finalmask/tcp_test.go new file mode 100644 index 000000000000..7febb185695d --- /dev/null +++ b/transport/internet/finalmask/tcp_test.go @@ -0,0 +1,123 @@ +package finalmask_test + +import ( + "bytes" + "io" + "net" + "testing" + "time" + + "github.com/xtls/xray-core/transport/internet/finalmask" + "github.com/xtls/xray-core/transport/internet/finalmask/header/custom" +) + +func mustSendRecvTcp( + t *testing.T, + from net.Conn, + to net.Conn, + msg []byte, +) { + t.Helper() + + go func() { + _, err := from.Write(msg) + if err != nil { + t.Error(err) + } + }() + + buf := make([]byte, 1024) + n, err := io.ReadFull(to, buf[:len(msg)]) + if err != nil { + t.Fatal(err) + } + + if n != len(msg) { + t.Fatalf("unexpected size: %d", n) + } + + if !bytes.Equal(buf[:n], msg) { + t.Fatalf("unexpected data %q", buf[:n]) + } +} + +type layerMaskTcp struct { + name string + mask finalmask.Tcpmask +} + +func TestConnReadWrite(t *testing.T) { + cases := []layerMaskTcp{ + { + name: "custom", + mask: &custom.TCPConfig{ + Clients: []*custom.TCPSequence{ + { + Sequence: []*custom.TCPItem{ + { + Packet: []byte{1}, + }, + { + Rand: 1, + }, + }, + }, + }, + Servers: []*custom.TCPSequence{ + { + Sequence: []*custom.TCPItem{ + { + Packet: []byte{2}, + }, + { + Rand: 1, + }, + }, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + mask := c.mask + + maskManager := finalmask.NewTcpmaskManager([]finalmask.Tcpmask{mask}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + client, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + client, err = maskManager.WrapConnClient(client) + if err != nil { + t.Fatal(err) + } + + server, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + + server, err = maskManager.WrapConnServer(server) + if err != nil { + t.Fatal(err) + } + + _ = client.SetDeadline(time.Now().Add(time.Second)) + _ = server.SetDeadline(time.Now().Add(time.Second)) + + mustSendRecvTcp(t, client, server, []byte("client -> server")) + mustSendRecvTcp(t, server, client, []byte("server -> client")) + + mustSendRecvTcp(t, client, server, []byte{}) + mustSendRecvTcp(t, server, client, []byte{}) + }) + } +} diff --git a/transport/internet/finalmask/udp_test.go b/transport/internet/finalmask/udp_test.go index bc4962ff6437..49cdd9233c3c 100644 --- a/transport/internet/finalmask/udp_test.go +++ b/transport/internet/finalmask/udp_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/xtls/xray-core/transport/internet/finalmask" + "github.com/xtls/xray-core/transport/internet/finalmask/header/custom" "github.com/xtls/xray-core/transport/internet/finalmask/header/dns" "github.com/xtls/xray-core/transport/internet/finalmask/header/srtp" "github.com/xtls/xray-core/transport/internet/finalmask/header/utp" @@ -82,6 +83,27 @@ func TestPacketConnReadWrite(t *testing.T) { name: "wireguard", mask: &wireguard.Config{}, }, + { + name: "custom", + mask: &custom.UDPConfig{ + Client: []*custom.UDPItem{ + { + Packet: []byte{1}, + }, + { + Rand: 1, + }, + }, + Server: []*custom.UDPItem{ + { + Packet: []byte{1}, + }, + { + Rand: 1, + }, + }, + }, + }, { name: "salamander", mask: &salamander.Config{Password: "1234"}, @@ -98,7 +120,6 @@ func TestPacketConnReadWrite(t *testing.T) { if err != nil { t.Fatal(err) } - defer client.Close() client, err = maskManager.WrapPacketConnClient(client) if err != nil { @@ -109,7 +130,6 @@ func TestPacketConnReadWrite(t *testing.T) { if err != nil { t.Fatal(err) } - defer server.Close() server, err = maskManager.WrapPacketConnServer(server) if err != nil { diff --git a/transport/internet/finalmask/xdns/client.go b/transport/internet/finalmask/xdns/client.go index 9d80bc225762..81b1366269bc 100644 --- a/transport/internet/finalmask/xdns/client.go +++ b/transport/internet/finalmask/xdns/client.go @@ -6,12 +6,15 @@ import ( "crypto/rand" "encoding/base32" "encoding/binary" + go_errors "errors" "io" "net" "sync" "time" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) const ( @@ -31,7 +34,7 @@ type packet struct { } type xdnsConnClient struct { - conn net.PacketConn + net.PacketConn clientID []byte domain Name @@ -44,28 +47,24 @@ type xdnsConnClient struct { mutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { - if !end { - return nil, errors.New("xdns requires being at the outermost level") - } - +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { domain, err := ParseName(c.Domain) if err != nil { return nil, err } conn := &xdnsConnClient{ - conn: raw, + PacketConn: raw, clientID: make([]byte, 8), domain: domain, pollChan: make(chan struct{}, pollLimit), readQueue: make(chan *packet, 128), - writeQueue: make(chan *packet, 128), + writeQueue: make(chan *packet, 256), } - rand.Read(conn.clientID) + common.Must2(rand.Read(conn.clientID)) go conn.recvLoop() go conn.sendLoop() @@ -74,20 +73,24 @@ func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, err } func (c *xdnsConnClient) recvLoop() { + var buf [finalmask.UDPSize]byte + for { if c.closed { break } - var buf [4096]byte - - n, addr, err := c.conn.ReadFrom(buf[:]) - if err != nil { + n, addr, err := c.PacketConn.ReadFrom(buf[:]) + if err != nil || n == 0 { + if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) { + break + } continue } resp, err := MessageFromWireFormat(buf[:n]) if err != nil { + errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err) continue } @@ -110,6 +113,7 @@ func (c *xdnsConnClient) recvLoop() { addr: addr, }: default: + errors.LogDebug(context.Background(), addr, " mask read err queue full") } } @@ -121,8 +125,16 @@ func (c *xdnsConnClient) recvLoop() { } } + errors.LogDebug(context.Background(), "xdns closed") + close(c.pollChan) close(c.readQueue) + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.closed = true + close(c.writeQueue) } func (c *xdnsConnClient) sendLoop() { @@ -178,25 +190,26 @@ func (c *xdnsConnClient) sendLoop() { } if p != nil { - _, _ = c.conn.WriteTo(p.p, p.addr) + _, err := c.PacketConn.WriteTo(p.p, p.addr) + if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) { + c.closed = true + break + } } } } -func (c *xdnsConnClient) Size() int32 { - return 0 -} - func (c *xdnsConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) { packet, ok := <-c.readQueue if !ok { - return 0, nil, io.EOF + return 0, nil, net.ErrClosed } - n = copy(p, packet.p) - if n != len(packet.p) { - return 0, nil, io.ErrShortBuffer + if len(p) < len(packet.p) { + errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p)) + return 0, packet.addr, nil } - return n, packet.addr, nil + copy(p, packet.p) + return len(packet.p), packet.addr, nil } func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { @@ -204,13 +217,13 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { defer c.mutex.Unlock() if c.closed { - return 0, errors.New("xdns closed") + return 0, io.ErrClosedPipe } encoded, err := encode(p, c.clientID, c.domain) if err != nil { - errors.LogDebug(context.Background(), "xdns encode err ", err) - return 0, errors.New("xdns encode").Base(err) + errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p)) + return 0, nil } select { @@ -220,38 +233,14 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { }: return len(p), nil default: - return 0, errors.New("xdns queue full") + errors.LogDebug(context.Background(), addr, " mask write err queue full") + return 0, nil } } func (c *xdnsConnClient) Close() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.closed { - return nil - } - c.closed = true - close(c.writeQueue) - - return c.conn.Close() -} - -func (c *xdnsConnClient) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *xdnsConnClient) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *xdnsConnClient) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return c.PacketConn.Close() } func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { diff --git a/transport/internet/finalmask/xdns/config.go b/transport/internet/finalmask/xdns/config.go index cf30902aee4d..157102dafa2b 100644 --- a/transport/internet/finalmask/xdns/config.go +++ b/transport/internet/finalmask/xdns/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, end) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, end) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/xdns/dns_test.go b/transport/internet/finalmask/xdns/dns_test.go index b07f57b9758c..aa163476d9f1 100644 --- a/transport/internet/finalmask/xdns/dns_test.go +++ b/transport/internet/finalmask/xdns/dns_test.go @@ -557,6 +557,8 @@ func TestEncodeRDataTXT(t *testing.T) { if len(encoded) > 256 { t.Errorf("EncodeRDataTXT(%d bytes) returned %d bytes", len(p), len(encoded)) } + + fmt.Println(EncodeRDataTXT(nil)) } func TestRDataTXTRoundTrip(t *testing.T) { diff --git a/transport/internet/finalmask/xdns/server.go b/transport/internet/finalmask/xdns/server.go index 2a5ec6cb8fe3..00afe57200fc 100644 --- a/transport/internet/finalmask/xdns/server.go +++ b/transport/internet/finalmask/xdns/server.go @@ -4,16 +4,18 @@ import ( "bytes" "context" "encoding/binary" + go_errors "errors" "io" "net" "sync" "time" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" ) const ( - idleTimeout = 2 * time.Minute + idleTimeout = 10 * time.Second responseTTL = 60 maxResponseDelay = 1 * time.Second ) @@ -42,13 +44,13 @@ type record struct { } type queue struct { - lash time.Time + last time.Time queue chan []byte stash chan []byte } type xdnsConnServer struct { - conn net.PacketConn + net.PacketConn domain Name @@ -60,23 +62,19 @@ type xdnsConnServer struct { mutex sync.Mutex } -func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { - if !end { - return nil, errors.New("xdns requires being at the outermost level") - } - +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { domain, err := ParseName(c.Domain) if err != nil { return nil, err } conn := &xdnsConnServer{ - conn: raw, + PacketConn: raw, domain: domain, - ch: make(chan *record, 100), - readQueue: make(chan *packet, 128), + ch: make(chan *record, 500), + readQueue: make(chan *packet, 256), writeQueueMap: make(map[string]*queue), } @@ -99,7 +97,7 @@ func (c *xdnsConnServer) clean() { now := time.Now() for key, q := range c.writeQueueMap { - if now.Sub(q.lash) >= idleTimeout { + if now.Sub(q.last) >= idleTimeout { close(q.queue) close(q.stash) delete(c.writeQueueMap, key) @@ -118,9 +116,6 @@ func (c *xdnsConnServer) clean() { } func (c *xdnsConnServer) ensureQueue(addr net.Addr) *queue { - c.mutex.Lock() - defer c.mutex.Unlock() - if c.closed { return nil } @@ -128,12 +123,12 @@ func (c *xdnsConnServer) ensureQueue(addr net.Addr) *queue { q, ok := c.writeQueueMap[addr.String()] if !ok { q = &queue{ - queue: make(chan []byte, 128), + queue: make(chan []byte, 512), stash: make(chan []byte, 1), } c.writeQueueMap[addr.String()] = q } - q.lash = time.Now() + q.last = time.Now() return q } @@ -153,19 +148,24 @@ func (c *xdnsConnServer) stash(queue *queue, p []byte) { } func (c *xdnsConnServer) recvLoop() { + var buf [finalmask.UDPSize]byte + for { if c.closed { break } - var buf [4096]byte - n, addr, err := c.conn.ReadFrom(buf[:]) - if err != nil { + n, addr, err := c.PacketConn.ReadFrom(buf[:]) + if err != nil || n == 0 { + if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) { + break + } continue } query, err := MessageFromWireFormat(buf[:n]) if err != nil { + errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err) continue } @@ -190,6 +190,7 @@ func (c *xdnsConnServer) recvLoop() { addr: clientIDToAddr(clientID), }: default: + errors.LogDebug(context.Background(), addr, " ", clientID, " mask read err queue full") } } } else { @@ -202,12 +203,25 @@ func (c *xdnsConnServer) recvLoop() { select { case c.ch <- &record{resp, addr, clientIDToAddr(clientID)}: default: + errors.LogDebug(context.Background(), addr, " ", clientID, " mask read err record queue full") } } } + errors.LogDebug(context.Background(), "xdns closed") + close(c.ch) close(c.readQueue) + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.closed = true + for key, q := range c.writeQueueMap { + close(q.queue) + close(q.stash) + delete(c.writeQueueMap, key) + } } func (c *xdnsConnServer) sendLoop() { @@ -238,24 +252,28 @@ func (c *xdnsConnServer) sendLoop() { var payload bytes.Buffer limit := maxEncodedPayload timer := time.NewTimer(maxResponseDelay) + for { - queue := c.ensureQueue(rec.ClientAddr) - if queue == nil { + c.mutex.Lock() + q := c.ensureQueue(rec.ClientAddr) + if q == nil { + c.mutex.Unlock() return } + c.mutex.Unlock() var p []byte select { - case p = <-queue.stash: + case p = <-q.stash: default: select { - case p = <-queue.stash: - case p = <-queue.queue: + case p = <-q.stash: + case p = <-q.queue: default: select { - case p = <-queue.stash: - case p = <-queue.queue: + case p = <-q.stash: + case p = <-q.queue: case <-timer.C: case nextRec = <-c.ch: } @@ -269,33 +287,31 @@ func (c *xdnsConnServer) sendLoop() { } limit -= 2 + len(p) - if payload.Len() == 0 { - - } else if limit < 0 { - c.stash(queue, p) - + if payload.Len() > 0 && limit < 0 { + c.stash(q, p) break } - if int(uint16(len(p))) != len(p) { - panic(len(p)) - } + // if len(p) > 65535 { + // panic(len(p)) + // } _ = binary.Write(&payload, binary.BigEndian, uint16(len(p))) payload.Write(p) } - timer.Stop() + timer.Stop() rec.Resp.Answer[0].Data = EncodeRDataTXT(payload.Bytes()) } buf, err := rec.Resp.WireFormat() if err != nil { + errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns wireformat err ", err) continue } if len(buf) > maxUDPPayload { - errors.LogDebug(context.Background(), "xdns server truncate ", len(buf)) + errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns truncate ", len(buf)) buf = buf[:maxUDPPayload] buf[2] |= 0x02 } @@ -304,37 +320,39 @@ func (c *xdnsConnServer) sendLoop() { return } - _, _ = c.conn.WriteTo(buf, rec.Addr) + _, err = c.PacketConn.WriteTo(buf, rec.Addr) + if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) { + c.closed = true + break + } } } -func (c *xdnsConnServer) Size() int32 { - return 0 -} - func (c *xdnsConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { packet, ok := <-c.readQueue if !ok { - return 0, nil, io.EOF + return 0, nil, net.ErrClosed } - n = copy(p, packet.p) - if n != len(packet.p) { - return 0, nil, io.ErrShortBuffer + if len(p) < len(packet.p) { + errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p)) + return 0, packet.addr, nil } - return n, packet.addr, nil + copy(p, packet.p) + return len(packet.p), packet.addr, nil } func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { - q := c.ensureQueue(addr) - if q == nil { - return 0, errors.New("xdns closed") + if len(p)+2 > maxEncodedPayload { + errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", maxEncodedPayload) + return 0, nil } c.mutex.Lock() defer c.mutex.Unlock() - if c.closed { - return 0, errors.New("xdns closed") + q := c.ensureQueue(addr) + if q == nil { + return 0, io.ErrClosedPipe } buf := make([]byte, len(p)) @@ -344,42 +362,14 @@ func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { case q.queue <- buf: return len(p), nil default: - return 0, errors.New("xdns queue full") + // errors.LogDebug(context.Background(), addr, " mask write err queue full") + return 0, nil } } func (c *xdnsConnServer) Close() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.closed { - return nil - } - c.closed = true - for key, q := range c.writeQueueMap { - close(q.queue) - close(q.stash) - delete(c.writeQueueMap, key) - } - - return c.conn.Close() -} - -func (c *xdnsConnServer) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *xdnsConnServer) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *xdnsConnServer) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *xdnsConnServer) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return c.PacketConn.Close() } func nextPacketServer(r *bytes.Reader) ([]byte, error) { diff --git a/transport/internet/finalmask/xicmp/client.go b/transport/internet/finalmask/xicmp/client.go index 5530861bb4ed..0c900dc3f05a 100644 --- a/transport/internet/finalmask/xicmp/client.go +++ b/transport/internet/finalmask/xicmp/client.go @@ -10,6 +10,9 @@ import ( "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/finalmask" + "github.com/xtls/xray-core/transport/internet/hysteria/udphop" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -51,8 +54,10 @@ type xicmpConnClient struct { mutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { - if !end { +func NewConnClient(c *Config, raw net.PacketConn, level int) (net.PacketConn, error) { + _, ok1 := raw.(*internet.FakePacketConn) + _, ok2 := raw.(*udphop.UdpHopPacketConn) + if level != 0 || ok1 || ok2 { return nil, errors.New("xicmp requires being at the outermost level") } @@ -86,7 +91,7 @@ func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, err pollChan: make(chan struct{}, pollLimit), readQueue: make(chan *packet, 128), - writeQueue: make(chan *packet, 128), + writeQueue: make(chan *packet, 256), } go conn.recvLoop() @@ -122,8 +127,8 @@ func (c *xicmpConnClient) encode(p []byte) ([]byte, error) { return nil, err } - if len(buf) > 8192 { - return nil, errors.New("xicmp len(buf) > 8192") + if len(buf) > finalmask.UDPSize { + return nil, errors.New("xicmp len(buf) > finalmask.UDPSize") } c.seqStatus[c.seq] = &seqStatus{ @@ -144,13 +149,13 @@ func (c *xicmpConnClient) encode(p []byte) ([]byte, error) { } func (c *xicmpConnClient) recvLoop() { + var buf [finalmask.UDPSize]byte + for { if c.closed { break } - var buf [8192]byte - n, addr, err := c.icmpConn.ReadFrom(buf[:]) if err != nil { continue @@ -201,6 +206,7 @@ func (c *xicmpConnClient) recvLoop() { addr: &net.UDPAddr{IP: addr.(*net.IPAddr).IP}, }: default: + errors.LogDebug(context.Background(), addr, " ", echo.Seq, " ", echo.ID, " mask read err queue full") } select { @@ -210,8 +216,16 @@ func (c *xicmpConnClient) recvLoop() { } } + errors.LogDebug(context.Background(), "xicmp closed") + close(c.pollChan) close(c.readQueue) + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.closed = true + close(c.writeQueue) } func (c *xicmpConnClient) sendLoop() { @@ -269,39 +283,37 @@ func (c *xicmpConnClient) sendLoop() { if p != nil { _, err := c.icmpConn.WriteTo(p.p, p.addr) if err != nil { - errors.LogDebug(context.Background(), "xicmp writeto err ", err) + errors.LogDebug(context.Background(), p.addr, " xicmp writeto err ", err) } } } } -func (c *xicmpConnClient) Size() int32 { - return 0 -} - func (c *xicmpConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) { packet, ok := <-c.readQueue if !ok { - return 0, nil, io.EOF + return 0, nil, net.ErrClosed } - n = copy(p, packet.p) - if n != len(packet.p) { - return 0, nil, io.ErrShortBuffer + if len(p) < len(packet.p) { + errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p)) + return 0, packet.addr, nil } - return n, packet.addr, nil + copy(p, packet.p) + return len(packet.p), packet.addr, nil } func (c *xicmpConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { encoded, err := c.encode(p) if err != nil { - return 0, errors.New("xicmp encode").Base(err) + errors.LogDebug(context.Background(), addr, " xicmp wireformat err ", err) + return 0, nil } c.mutex.Lock() defer c.mutex.Unlock() if c.closed { - return 0, errors.New("xicmp closed") + return 0, io.ErrClosedPipe } select { @@ -311,21 +323,13 @@ func (c *xicmpConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { }: return len(p), nil default: - return 0, errors.New("xicmp queue full") + errors.LogDebug(context.Background(), addr, " mask write err queue full") + return 0, nil } } func (c *xicmpConnClient) Close() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.closed { - return nil - } - c.closed = true - close(c.writeQueue) - _ = c.icmpConn.Close() return c.conn.Close() } diff --git a/transport/internet/finalmask/xicmp/config.go b/transport/internet/finalmask/xicmp/config.go index 81a483af8539..c570ce96817e 100644 --- a/transport/internet/finalmask/xicmp/config.go +++ b/transport/internet/finalmask/xicmp/config.go @@ -7,10 +7,10 @@ import ( func (c *Config) UDP() { } -func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnClient(c, raw, end) +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnClient(c, raw, level) } -func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { - return NewConnServer(c, raw, end) +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewConnServer(c, raw, level) } diff --git a/transport/internet/finalmask/xicmp/server.go b/transport/internet/finalmask/xicmp/server.go index 8d2ee256ef74..78d3397eb58b 100644 --- a/transport/internet/finalmask/xicmp/server.go +++ b/transport/internet/finalmask/xicmp/server.go @@ -10,13 +10,14 @@ import ( "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) const ( - idleTimeout = 2 * time.Minute + idleTimeout = 10 * time.Second maxResponseDelay = 1 * time.Second ) @@ -29,7 +30,7 @@ type record struct { } type queue struct { - lash time.Time + last time.Time queue chan []byte } @@ -49,8 +50,8 @@ type xicmpConnServer struct { mutex sync.Mutex } -func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { - if !end { +func NewConnServer(c *Config, raw net.PacketConn, level int) (net.PacketConn, error) { + if level != 0 { return nil, errors.New("xicmp requires being at the outermost level") } @@ -76,8 +77,8 @@ func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, err proto: proto, config: c, - ch: make(chan *record, 100), - readQueue: make(chan *packet, 128), + ch: make(chan *record, 500), + readQueue: make(chan *packet, 256), writeQueueMap: make(map[string]*queue), } @@ -100,7 +101,7 @@ func (c *xicmpConnServer) clean() { now := time.Now() for key, q := range c.writeQueueMap { - if now.Sub(q.lash) >= idleTimeout { + if now.Sub(q.last) >= idleTimeout { close(q.queue) delete(c.writeQueueMap, key) } @@ -118,9 +119,6 @@ func (c *xicmpConnServer) clean() { } func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue { - c.mutex.Lock() - defer c.mutex.Unlock() - if c.closed { return nil } @@ -128,11 +126,11 @@ func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue { q, ok := c.writeQueueMap[addr.String()] if !ok { q = &queue{ - queue: make(chan []byte, 128), + queue: make(chan []byte, 512), } c.writeQueueMap[addr.String()] = q } - q.lash = time.Now() + q.last = time.Now() return q } @@ -159,8 +157,8 @@ func (c *xicmpConnServer) encode(p []byte, id int, seq int, needSeqByte bool, se return nil, err } - if len(buf) > 8192 { - return nil, errors.New("xicmp len(buf) > 8192") + if len(buf) > finalmask.UDPSize { + return nil, errors.New("xicmp len(buf) > finalmask.UDPSize") } return buf, nil @@ -177,13 +175,13 @@ func (c *xicmpConnServer) randUntil(b1 byte) byte { } func (c *xicmpConnServer) recvLoop() { + var buf [finalmask.UDPSize]byte + for { if c.closed { break } - var buf [8192]byte - n, addr, err := c.icmpConn.ReadFrom(buf[:]) if err != nil { continue @@ -225,6 +223,7 @@ func (c *xicmpConnServer) recvLoop() { }, }: default: + errors.LogDebug(context.Background(), addr, " ", echo.ID, " ", echo.Seq, " mask read err queue full") } } @@ -240,11 +239,23 @@ func (c *xicmpConnServer) recvLoop() { }, }: default: + errors.LogDebug(context.Background(), addr, " ", echo.ID, " ", echo.Seq, " mask read err record queue full") } } + errors.LogDebug(context.Background(), "xicmp closed") + close(c.ch) close(c.readQueue) + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.closed = true + for key, q := range c.writeQueueMap { + close(q.queue) + delete(c.writeQueueMap, key) + } } func (c *xicmpConnServer) sendLoop() { @@ -261,20 +272,23 @@ func (c *xicmpConnServer) sendLoop() { } } - queue := c.ensureQueue(rec.addr) - if queue == nil { + c.mutex.Lock() + q := c.ensureQueue(rec.addr) + if q == nil { + c.mutex.Unlock() return } + c.mutex.Unlock() var p []byte timer := time.NewTimer(maxResponseDelay) select { - case p = <-queue.queue: + case p = <-q.queue: default: select { - case p = <-queue.queue: + case p = <-q.queue: case <-timer.C: case nextRec = <-c.ch: } @@ -288,6 +302,7 @@ func (c *xicmpConnServer) sendLoop() { buf, err := c.encode(p, rec.id, rec.seq, rec.needSeqByte, rec.seqByte) if err != nil { + errors.LogDebug(context.Background(), rec.addr, " ", rec.id, " ", rec.seq, " xicmp wireformat err ", err) continue } @@ -297,38 +312,36 @@ func (c *xicmpConnServer) sendLoop() { _, err = c.icmpConn.WriteTo(buf, &net.IPAddr{IP: rec.addr.(*net.UDPAddr).IP}) if err != nil { - errors.LogDebug(context.Background(), "xicmp writeto err ", err) + errors.LogDebug(context.Background(), rec.addr, " ", rec.id, " ", rec.seq, " xicmp writeto err ", err) } } } -func (c *xicmpConnServer) Size() int32 { - return 0 -} - func (c *xicmpConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { packet, ok := <-c.readQueue if !ok { - return 0, nil, io.EOF + return 0, nil, net.ErrClosed } - n = copy(p, packet.p) - if n != len(packet.p) { - return 0, nil, io.ErrShortBuffer + if len(p) < len(packet.p) { + errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p)) + return 0, packet.addr, nil } - return n, packet.addr, nil + copy(p, packet.p) + return len(packet.p), packet.addr, nil } func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { - q := c.ensureQueue(addr) - if q == nil { - return 0, errors.New("xicmp closed") + if len(p)+8+1 > finalmask.UDPSize { + errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+8+1 > ", finalmask.UDPSize) + return 0, nil } c.mutex.Lock() defer c.mutex.Unlock() - if c.closed { - return 0, errors.New("xicmp closed") + q := c.ensureQueue(addr) + if q == nil { + return 0, io.ErrClosedPipe } buf := make([]byte, len(p)) @@ -338,24 +351,13 @@ func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { case q.queue <- buf: return len(p), nil default: - return 0, errors.New("xicmp queue full") + // errors.LogDebug(context.Background(), addr, " mask write err queue full") + return 0, nil } } func (c *xicmpConnServer) Close() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.closed { - return nil - } - c.closed = true - for key, q := range c.writeQueueMap { - close(q.queue) - delete(c.writeQueueMap, key) - } - _ = c.icmpConn.Close() return c.conn.Close() } diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index e0b0aa03f8a7..0decaf9b2e7a 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -125,6 +125,15 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) if err == nil { + if streamSettings.TcpmaskManager != nil { + newConn, err := streamSettings.TcpmaskManager.WrapConnClient(c) + if err != nil { + c.Close() + return nil, errors.New("mask err").Base(err) + } + c = newConn + } + if tlsConfig != nil { config := tlsConfig.GetTLSConfig() if config.ServerName == "" && address.Family().IsDomain() { diff --git a/transport/internet/grpc/hub.go b/transport/internet/grpc/hub.go index ae8788fab5c8..34f8f1c0e60d 100644 --- a/transport/internet/grpc/hub.go +++ b/transport/internet/grpc/hub.go @@ -120,6 +120,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i } } + if settings.TcpmaskManager != nil { + streamListener, _ = settings.TcpmaskManager.WrapListener(streamListener) + } + errors.LogDebug(ctx, "gRPC listen for service name `"+grpcSettings.getServiceName()+"` tun `"+grpcSettings.getTunStreamName()+"` multi tun `"+grpcSettings.getTunMultiStreamName()+"`") encoding.RegisterGRPCServiceServerX(s, listener, grpcSettings.getServiceName(), grpcSettings.getTunStreamName(), grpcSettings.getTunMultiStreamName()) diff --git a/transport/internet/httpupgrade/dialer.go b/transport/internet/httpupgrade/dialer.go index 4d718eb8036e..eacbded4edb2 100644 --- a/transport/internet/httpupgrade/dialer.go +++ b/transport/internet/httpupgrade/dialer.go @@ -52,6 +52,15 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * return nil, err } + if streamSettings.TcpmaskManager != nil { + newConn, err := streamSettings.TcpmaskManager.WrapConnClient(pconn) + if err != nil { + pconn.Close() + return nil, errors.New("mask err").Base(err) + } + pconn = newConn + } + var conn net.Conn var requestURL url.URL tConfig := tls.ConfigFromStreamSettings(streamSettings) diff --git a/transport/internet/httpupgrade/hub.go b/transport/internet/httpupgrade/hub.go index 778a35e5e0e0..8e70ad08a735 100644 --- a/transport/internet/httpupgrade/hub.go +++ b/transport/internet/httpupgrade/hub.go @@ -142,6 +142,10 @@ func ListenHTTPUpgrade(ctx context.Context, address net.Address, port net.Port, errors.LogInfo(ctx, "listening TCP(for HttpUpgrade) on ", address, ":", port) } + if streamSettings.TcpmaskManager != nil { + listener, _ = streamSettings.TcpmaskManager.WrapListener(listener) + } + if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { errors.LogWarning(ctx, "accepting PROXY protocol") } diff --git a/transport/internet/hysteria/conn.go b/transport/internet/hysteria/conn.go index be4b0f595cb5..cf0920d8a9c3 100644 --- a/transport/internet/hysteria/conn.go +++ b/transport/internet/hysteria/conn.go @@ -34,6 +34,7 @@ func (i *interConn) Read(b []byte) (int, error) { func (i *interConn) Write(b []byte) (int, error) { if i.client { i.mutex.Lock() + defer i.mutex.Unlock() if i.client { buf := make([]byte, 0, quicvarint.Len(FrameTypeTCPRequest)+len(b)) buf = quicvarint.Append(buf, FrameTypeTCPRequest) @@ -45,7 +46,6 @@ func (i *interConn) Write(b []byte) (int, error) { i.client = false return len(b), nil } - i.mutex.Unlock() } return i.stream.Write(b) diff --git a/transport/internet/hysteria/dialer.go b/transport/internet/hysteria/dialer.go index c1e8d150e1aa..9a4c6374b0c3 100644 --- a/transport/internet/hysteria/dialer.go +++ b/transport/internet/hysteria/dialer.go @@ -7,6 +7,7 @@ import ( "math/rand" "net/http" "net/url" + "reflect" "strconv" "sync" "time" @@ -16,6 +17,7 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/task" hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx" "github.com/xtls/xray-core/transport/internet" @@ -162,17 +164,33 @@ func (c *client) dial() error { return errors.New("failed to dial to dest").Base(err) } - remote := raw.RemoteAddr() + var pktConn net.PacketConn + var remote *net.UDPAddr - pktConn, ok := raw.(net.PacketConn) - if !ok { + switch conn := raw.(type) { + case *internet.PacketConnWrapper: + pktConn = conn.PacketConn + remote = conn.RemoteAddr().(*net.UDPAddr) + case *net.UDPConn: + pktConn = conn + remote = conn.RemoteAddr().(*net.UDPAddr) + case *cnc.Connection: + fakeConn := &internet.FakePacketConn{Conn: conn} + pktConn = fakeConn + remote = fakeConn.RemoteAddr().(*net.UDPAddr) + + if len(c.config.Ports) > 0 { + raw.Close() + return errors.New("udphop requires being at the outermost level") + } + default: raw.Close() - return errors.New("raw is not PacketConn") + return errors.New("unknown conn ", reflect.TypeOf(conn)) } if len(c.config.Ports) > 0 { addr := &udphop.UDPHopAddr{ - IP: remote.(*net.UDPAddr).IP, + IP: remote.IP, Ports: c.config.Ports, } pktConn, err = udphop.NewUDPHopPacketConn(addr, c.config.IntervalMin, c.config.IntervalMax, c.udphopDialer, pktConn, index) @@ -341,20 +359,28 @@ func (c *client) udphopDialer(addr *net.UDPAddr) (net.PacketConn, error) { defer c.mutex.Unlock() if c.status() != StatusActive { - errors.LogDebug(c.ctx, "stop hop on disconnected QUIC waiting to be closed") + errors.LogDebug(c.ctx, "skip hop: disconnected QUIC") return nil, errors.New() } - raw, err := internet.DialSystem(c.ctx, net.DestinationFromAddr(addr), c.socketConfig) + raw, err := internet.DialSystem(c.ctx, net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), c.socketConfig) if err != nil { - errors.LogDebug(c.ctx, "failed to dial to dest skip hop") + errors.LogDebug(c.ctx, "skip hop: failed to dial to dest") return nil, errors.New() } - pktConn, ok := raw.(net.PacketConn) - if !ok { - errors.LogDebug(c.ctx, "raw is not PacketConn skip hop") - raw.Close() + var pktConn net.PacketConn + + switch conn := raw.(type) { + case *internet.PacketConnWrapper: + pktConn = conn.PacketConn + case *net.UDPConn: + pktConn = conn + case *cnc.Connection: + errors.LogDebug(c.ctx, "skip hop: udphop requires being at the outermost level") + return nil, errors.New() + default: + errors.LogDebug(c.ctx, "skip hop: unknown conn ", reflect.TypeOf(conn)) return nil, errors.New() } diff --git a/transport/internet/hysteria/udphop/conn.go b/transport/internet/hysteria/udphop/conn.go index c19de476150f..663b27e5bdf7 100644 --- a/transport/internet/hysteria/udphop/conn.go +++ b/transport/internet/hysteria/udphop/conn.go @@ -18,7 +18,7 @@ const ( defaultHopInterval = 30 * time.Second ) -type udpHopPacketConn struct { +type UdpHopPacketConn struct { Addr net.Addr Addrs []net.Addr HopIntervalMin int64 @@ -73,7 +73,7 @@ func NewUDPHopPacketConn(addr *UDPHopAddr, intervalMin int64, intervalMax int64, // if err != nil { // return nil, err // } - hConn := &udpHopPacketConn{ + hConn := &UdpHopPacketConn{ Addr: addr, Addrs: addrs, HopIntervalMin: intervalMin, @@ -95,7 +95,7 @@ func NewUDPHopPacketConn(addr *UDPHopAddr, intervalMin int64, intervalMax int64, return hConn, nil } -func (u *udpHopPacketConn) recvLoop(conn net.PacketConn) { +func (u *UdpHopPacketConn) recvLoop(conn net.PacketConn) { for { buf := u.bufPool.Get().([]byte) n, addr, err := conn.ReadFrom(buf) @@ -120,7 +120,7 @@ func (u *udpHopPacketConn) recvLoop(conn net.PacketConn) { } } -func (u *udpHopPacketConn) hopLoop() { +func (u *UdpHopPacketConn) hopLoop() { ticker := time.NewTicker(time.Duration(crypto.RandBetween(u.HopIntervalMin, u.HopIntervalMax)) * time.Second) defer ticker.Stop() for { @@ -134,7 +134,7 @@ func (u *udpHopPacketConn) hopLoop() { } } -func (u *udpHopPacketConn) hop() { +func (u *UdpHopPacketConn) hop() { u.connMutex.Lock() defer u.connMutex.Unlock() if u.closed { @@ -170,7 +170,7 @@ func (u *udpHopPacketConn) hop() { go u.recvLoop(newConn) } -func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func (u *UdpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { for { select { case p := <-u.recvQueue: @@ -188,7 +188,7 @@ func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) } } -func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { +func (u *UdpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { u.connMutex.RLock() defer u.connMutex.RUnlock() if u.closed { @@ -199,7 +199,7 @@ func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex]) } -func (u *udpHopPacketConn) Close() error { +func (u *UdpHopPacketConn) Close() error { u.connMutex.Lock() defer u.connMutex.Unlock() if u.closed { @@ -218,13 +218,13 @@ func (u *udpHopPacketConn) Close() error { return err } -func (u *udpHopPacketConn) LocalAddr() net.Addr { +func (u *UdpHopPacketConn) LocalAddr() net.Addr { u.connMutex.RLock() defer u.connMutex.RUnlock() return u.currentConn.LocalAddr() } -func (u *udpHopPacketConn) SetDeadline(t time.Time) error { +func (u *UdpHopPacketConn) SetDeadline(t time.Time) error { u.connMutex.RLock() defer u.connMutex.RUnlock() if u.prevConn != nil { @@ -233,7 +233,7 @@ func (u *udpHopPacketConn) SetDeadline(t time.Time) error { return u.currentConn.SetDeadline(t) } -func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error { +func (u *UdpHopPacketConn) SetReadDeadline(t time.Time) error { u.connMutex.RLock() defer u.connMutex.RUnlock() if u.prevConn != nil { @@ -242,7 +242,7 @@ func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error { return u.currentConn.SetReadDeadline(t) } -func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error { +func (u *UdpHopPacketConn) SetWriteDeadline(t time.Time) error { u.connMutex.RLock() defer u.connMutex.RUnlock() if u.prevConn != nil { @@ -253,7 +253,7 @@ func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error { // UDP-specific methods below -func (u *udpHopPacketConn) SetReadBuffer(bytes int) error { +func (u *UdpHopPacketConn) SetReadBuffer(bytes int) error { u.connMutex.Lock() defer u.connMutex.Unlock() u.readBufferSize = bytes @@ -263,7 +263,7 @@ func (u *udpHopPacketConn) SetReadBuffer(bytes int) error { return trySetReadBuffer(u.currentConn, bytes) } -func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error { +func (u *UdpHopPacketConn) SetWriteBuffer(bytes int) error { u.connMutex.Lock() defer u.connMutex.Unlock() u.writeBufferSize = bytes @@ -273,7 +273,7 @@ func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error { return trySetWriteBuffer(u.currentConn, bytes) } -func (u *udpHopPacketConn) SyscallConn() (syscall.RawConn, error) { +func (u *UdpHopPacketConn) SyscallConn() (syscall.RawConn, error) { u.connMutex.RLock() defer u.connMutex.RUnlock() sc, ok := u.currentConn.(syscall.Conn) diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index a49155886b22..310bbd5386e7 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -3,6 +3,7 @@ package kcp import ( "context" "io" + reflect "reflect" "sync/atomic" "github.com/xtls/xray-core/common" @@ -10,6 +11,7 @@ import ( "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" @@ -49,24 +51,47 @@ func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet dest.Network = net.Network_UDP errors.LogInfo(ctx, "dialing mKCP to ", dest) - rawConn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to dial to dest: ", err).AtWarning().Base(err) } if streamSettings.UdpmaskManager != nil { - wrapper, ok := rawConn.(*internet.PacketConnWrapper) - if !ok { - rawConn.Close() - return nil, errors.New("raw is not PacketConnWrapper") - } - - raw := wrapper.Conn - - wrapper.Conn, err = streamSettings.UdpmaskManager.WrapPacketConnClient(raw) - if err != nil { - raw.Close() - return nil, errors.New("mask err").Base(err) + switch c := conn.(type) { + case *internet.PacketConnWrapper: + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c.PacketConn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + c.PacketConn = pktConn + case *net.UDPConn: + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = &internet.PacketConnWrapper{ + PacketConn: pktConn, + Dest: c.RemoteAddr().(*net.UDPAddr), + } + case *cnc.Connection: + fakeConn := &internet.FakePacketConn{Conn: c} + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = &internet.PacketConnWrapper{ + PacketConn: pktConn, + Dest: &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + }, + } + default: + conn.Close() + return nil, errors.New("unknown conn ", reflect.TypeOf(c)) } } @@ -76,12 +101,12 @@ func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet conv := uint16(atomic.AddUint32(&globalConv, 1)) session := NewConnection(ConnMetadata{ - LocalAddr: rawConn.LocalAddr(), - RemoteAddr: rawConn.RemoteAddr(), + LocalAddr: conn.LocalAddr(), + RemoteAddr: conn.RemoteAddr(), Conversation: conv, - }, rawConn, rawConn, kcpSettings) + }, conn, conn, kcpSettings) - go fetchInput(ctx, rawConn, reader, session) + go fetchInput(ctx, conn, reader, session) var iConn stat.Connection = session diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index a75014aff789..3d1fc01438f1 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -47,15 +47,18 @@ type AckList struct { flushCandidates []uint32 dirty bool + + mss uint32 } -func NewAckList(writer SegmentWriter) *AckList { +func NewAckList(writer SegmentWriter, mss uint32) *AckList { return &AckList{ writer: writer, timestamps: make([]uint32, 0, 128), numbers: make([]uint32, 0, 128), nextFlush: make([]uint32, 0, 128), flushCandidates: make([]uint32, 0, 128), + mss: mss, } } @@ -90,7 +93,7 @@ func (l *AckList) Clear(una uint32) { func (l *AckList) Flush(current uint32, rto uint32) { l.flushCandidates = l.flushCandidates[:0] - seg := NewAckSegment() + seg := NewAckSegment((int(l.mss) - 17) / 4) for i := 0; i < len(l.numbers); i++ { if l.nextFlush[i] > current { if len(l.flushCandidates) < cap(l.flushCandidates) { @@ -109,7 +112,7 @@ func (l *AckList) Flush(current uint32, rto uint32) { if seg.IsFull() { l.writer.Write(seg) seg.Release() - seg = NewAckSegment() + seg = NewAckSegment((int(l.mss) - 17) / 4) l.dirty = false } } @@ -144,7 +147,7 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker { window: NewReceivingWindow(), windowSize: kcp.Config.GetReceivingInFlightSize(), } - worker.acklist = NewAckList(worker) + worker.acklist = NewAckList(worker, kcp.mss+DataSegmentOverhead) return worker } diff --git a/transport/internet/kcp/segment.go b/transport/internet/kcp/segment.go index b97d25ea9366..2beaa0b67275 100644 --- a/transport/internet/kcp/segment.go +++ b/transport/internet/kcp/segment.go @@ -131,12 +131,22 @@ type AckSegment struct { ReceivingNext uint32 Timestamp uint32 NumberList []uint32 + + Limit int } const ackNumberLimit = 128 -func NewAckSegment() *AckSegment { - return new(AckSegment) +func NewAckSegment(limit int) *AckSegment { + if limit <= 0 { + limit = 1 + } + if limit > ackNumberLimit { + limit = ackNumberLimit + } + return &AckSegment{ + Limit: limit, + } } func (s *AckSegment) parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) { @@ -188,7 +198,7 @@ func (s *AckSegment) PutNumber(number uint32) { } func (s *AckSegment) IsFull() bool { - return len(s.NumberList) == ackNumberLimit + return len(s.NumberList) == s.Limit } func (s *AckSegment) IsEmpty() bool { @@ -290,7 +300,7 @@ func ReadSegment(buf []byte) (Segment, []byte) { case CommandData: seg = NewDataSegment() case CommandACK: - seg = NewAckSegment() + seg = NewAckSegment(128) default: seg = NewCmdOnlySegment() } diff --git a/transport/internet/kcp/segment_test.go b/transport/internet/kcp/segment_test.go index daa9098a77c6..cc12ea9bd03a 100644 --- a/transport/internet/kcp/segment_test.go +++ b/transport/internet/kcp/segment_test.go @@ -71,6 +71,7 @@ func TestACKSegment(t *testing.T) { ReceivingNext: 3, Timestamp: 10, NumberList: []uint32{1, 3, 5, 7, 9}, + Limit: 128, } nBytes := seg.ByteSize() diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 4d02a67169c2..6f39c20daaf0 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -117,6 +117,15 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea return nil, err } + if streamSettings.TcpmaskManager != nil { + newConn, err := streamSettings.TcpmaskManager.WrapConnClient(conn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = newConn + } + if realityConfig != nil { return reality.UClient(conn, realityConfig, ctxInner, dest) } @@ -173,7 +182,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea switch c := conn.(type) { case *internet.PacketConnWrapper: var ok bool - udpConn, ok = c.Conn.(*net.UDPConn) + udpConn, ok = c.PacketConn.(*net.UDPConn) if !ok { return nil, errors.New("PacketConnWrapper does not contain a UDP connection") } @@ -195,6 +204,15 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea } } + if streamSettings.UdpmaskManager != nil { + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(udpConn) + if err != nil { + udpConn.Close() + return nil, errors.New("mask err").Base(err) + } + udpConn = pktConn + } + return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) }, } diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 7e15724646c5..4f0832cb10ee 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -472,6 +472,16 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet if err != nil { return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) } + + if streamSettings.UdpmaskManager != nil { + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnServer(Conn) + if err != nil { + Conn.Close() + return nil, errors.New("mask err").Base(err) + } + Conn = pktConn + } + l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil) if err != nil { return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) @@ -499,6 +509,10 @@ func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSet errors.LogInfo(ctx, "listening TCP for XHTTP on ", address, ":", port) } + if !l.isH3 && streamSettings.TcpmaskManager != nil { + l.listener, _ = streamSettings.TcpmaskManager.WrapListener(l.listener) + } + // tcp/unix (h1/h2) if l.listener != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 015675762eec..16b3e9b0b414 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -2,7 +2,6 @@ package internet import ( "context" - "math/rand" "syscall" "time" @@ -83,8 +82,8 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne return nil, err } return &PacketConnWrapper{ - Conn: packetConn, - Dest: destAddr, + PacketConn: packetConn, + Dest: destAddr, }, nil } // Chrome defaults @@ -150,57 +149,21 @@ func (d *DefaultSystemDialer) DestIpAddress() net.IP { } type PacketConnWrapper struct { - Conn net.PacketConn + net.PacketConn Dest net.Addr } -func (c *PacketConnWrapper) Close() error { - return c.Conn.Close() -} - -func (c *PacketConnWrapper) LocalAddr() net.Addr { - return c.Conn.LocalAddr() -} - -func (c *PacketConnWrapper) RemoteAddr() net.Addr { - return c.Dest -} - -func (c *PacketConnWrapper) Write(p []byte) (int, error) { - return c.Conn.WriteTo(p, c.Dest) -} - func (c *PacketConnWrapper) Read(p []byte) (int, error) { - n, _, err := c.Conn.ReadFrom(p) + n, _, err := c.PacketConn.ReadFrom(p) return n, err } -func (c *PacketConnWrapper) WriteTo(p []byte, d net.Addr) (int, error) { - return c.Conn.WriteTo(p, d) -} - -func (c *PacketConnWrapper) ReadFrom(p []byte) (int, net.Addr, error) { - return c.Conn.ReadFrom(p) -} - -func (c *PacketConnWrapper) SetDeadline(t time.Time) error { - return c.Conn.SetDeadline(t) -} - -func (c *PacketConnWrapper) SetReadDeadline(t time.Time) error { - return c.Conn.SetReadDeadline(t) -} - -func (c *PacketConnWrapper) SetWriteDeadline(t time.Time) error { - return c.Conn.SetWriteDeadline(t) +func (c *PacketConnWrapper) Write(p []byte) (int, error) { + return c.PacketConn.WriteTo(p, c.Dest) } -func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) { - sc, ok := c.Conn.(syscall.Conn) - if !ok { - return nil, syscall.EINVAL - } - return sc.SyscallConn() +func (c *PacketConnWrapper) RemoteAddr() net.Addr { + return c.Dest } type SystemDialerAdapter interface { @@ -269,14 +232,15 @@ func (c *FakePacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) { } func (c *FakePacketConn) LocalAddr() net.Addr { - return &net.TCPAddr{ - IP: net.IP{byte(rand.Intn(256)), byte(rand.Intn(256)), byte(rand.Intn(256)), byte(rand.Intn(256))}, - Port: rand.Intn(65536), + return &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, } } -func (c *FakePacketConn) SetReadBuffer(bytes int) error { - // do nothing, this function is only there to suppress quic-go printing - // random warnings about UDP buffers to stdout - return nil +func (c *FakePacketConn) RemoteAddr() net.Addr { + return &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + } } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 5b966a004e0c..92fa7557f13a 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -24,6 +24,15 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me return nil, err } + if streamSettings.TcpmaskManager != nil { + newConn, err := streamSettings.TcpmaskManager.WrapConnClient(conn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = newConn + } + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { mitmServerName := session.MitmServerNameFromContext(ctx) mitmAlpn11 := session.MitmAlpn11FromContext(ctx) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 759dfc35a6b7..ede97499dac9 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -64,6 +64,10 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe errors.LogInfo(ctx, "listening TCP on ", address, ":", port) } + if streamSettings.TcpmaskManager != nil { + listener, _ = streamSettings.TcpmaskManager.WrapListener(listener) + } + if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { errors.LogWarning(ctx, "accepting PROXY protocol") } @@ -108,6 +112,7 @@ func (v *Listener) keepAccepting() { } continue } + go func() { if v.tlsConfig != nil { conn = tls.Server(conn, v.tlsConfig) diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index af25eb338c0a..c930c3551dc2 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -2,10 +2,12 @@ package udp import ( "context" + reflect "reflect" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -23,18 +25,41 @@ func init() { } if streamSettings != nil && streamSettings.UdpmaskManager != nil { - wrapper, ok := conn.(*internet.PacketConnWrapper) - if !ok { + switch c := conn.(type) { + case *internet.PacketConnWrapper: + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c.PacketConn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + c.PacketConn = pktConn + case *net.UDPConn: + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(c) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = &internet.PacketConnWrapper{ + PacketConn: pktConn, + Dest: c.RemoteAddr().(*net.UDPAddr), + } + case *cnc.Connection: + fakeConn := &internet.FakePacketConn{Conn: c} + pktConn, err := streamSettings.UdpmaskManager.WrapPacketConnClient(fakeConn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = &internet.PacketConnWrapper{ + PacketConn: pktConn, + Dest: &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + }, + } + default: conn.Close() - return nil, errors.New("conn is not PacketConnWrapper") - } - - raw := wrapper.Conn - - wrapper.Conn, err = streamSettings.UdpmaskManager.WrapPacketConnClient(raw) - if err != nil { - raw.Close() - return nil, errors.New("mask err").Base(err) + return nil, errors.New("unknown conn ", reflect.TypeOf(c)) } } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 5e41389304a1..e5354908d77c 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -48,7 +48,21 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in dialer := &websocket.Dialer{ NetDial: func(network, addr string) (net.Conn, error) { - return internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + if err != nil { + return nil, err + } + + if streamSettings.TcpmaskManager != nil { + newConn, err := streamSettings.TcpmaskManager.WrapConnClient(conn) + if err != nil { + conn.Close() + return nil, errors.New("mask err").Base(err) + } + conn = newConn + } + + return conn, err }, ReadBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024, @@ -70,6 +84,16 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in errors.LogErrorInner(ctx, err, "failed to dial to "+addr) return nil, err } + + if streamSettings.TcpmaskManager != nil { + newConn, err := streamSettings.TcpmaskManager.WrapConnClient(pconn) + if err != nil { + pconn.Close() + return nil, errors.New("mask err").Base(err) + } + pconn = newConn + } + // TLS and apply the handshake cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) if err := cn.WebsocketHandshakeContext(ctx); err != nil { diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index b47e7d581770..73799174ca24 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -129,6 +129,10 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet errors.LogInfo(ctx, "listening TCP(for WS) on ", address, ":", port) } + if streamSettings.TcpmaskManager != nil { + listener, _ = streamSettings.TcpmaskManager.WrapListener(listener) + } + if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { errors.LogWarning(ctx, "accepting PROXY protocol") }