diff --git a/app/connectiontracker/access.go b/app/connectiontracker/access.go new file mode 100644 index 000000000000..2b05b90085a7 --- /dev/null +++ b/app/connectiontracker/access.go @@ -0,0 +1,193 @@ +package connectiontracker + +import ( + "context" + "sync/atomic" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + clog "github.com/xtls/xray-core/common/log" + "github.com/xtls/xray-core/transport" +) + +type accessRecordKey struct{} + +// AccessRecord captures accepted-request access-log state until the request +// finishes and a final log line can be emitted. +type AccessRecord struct { + ID uint32 + + Msg *clog.AccessMessage + + RequestBytes int64 + ResponseBytes int64 + + LastActivity int64 + + cancel context.CancelFunc + finished atomic.Bool +} + +// ContextWithAccessRecord stores r in ctx for deferred access-log handling. +func ContextWithAccessRecord(ctx context.Context, r *AccessRecord) context.Context { + return context.WithValue(ctx, accessRecordKey{}, r) +} + +// AccessRecordFromContext returns the access record stored in ctx, if any. +func AccessRecordFromContext(ctx context.Context) *AccessRecord { + r, _ := ctx.Value(accessRecordKey{}).(*AccessRecord) + return r +} + +func (r *AccessRecord) touch() { + atomic.StoreInt64(&r.LastActivity, time.Now().UnixNano()) +} + +func (r *AccessRecord) addRequestBytes(n int64) { + if n <= 0 { + return + } + atomic.AddInt64(&r.RequestBytes, n) + r.touch() +} + +func (r *AccessRecord) addResponseBytes(n int64) { + if n <= 0 { + return + } + atomic.AddInt64(&r.ResponseBytes, n) + r.touch() +} + +func cloneAccessMessage(msg *clog.AccessMessage) *clog.AccessMessage { + if msg == nil { + return nil + } + cloned := *msg + return &cloned +} + +func (m *Manager) completeAccessRecord(r *AccessRecord, reason error) { + if r == nil || !r.finished.CompareAndSwap(false, true) { + return + } + msg := cloneAccessMessage(r.Msg) + if msg == nil { + return + } + if reason != nil { + msg.Reason = reason + } + msg.RequestBytes = atomic.LoadInt64(&r.RequestBytes) + msg.ResponseBytes = atomic.LoadInt64(&r.ResponseBytes) + clog.Record(msg) +} + +// NewAccessRecord creates a new access record for an accepted request. +func (m *Manager) NewAccessRecord(msg *clog.AccessMessage, cancel context.CancelFunc) *AccessRecord { + if msg == nil { + return nil + } + record := &AccessRecord{ + ID: atomic.AddUint32(&m.globalNext, 1), + Msg: msg, + cancel: cancel, + } + record.touch() + return record +} + +// FinishAccessRecord emits the final access log for r, including payload +// totals accumulated during the request lifetime. +func (m *Manager) FinishAccessRecord(r *AccessRecord) { + m.completeAccessRecord(r, nil) +} + +// AbortAccessRecord emits the final access log for r with an abort reason and +// cancels the tracked context if one was supplied. +func (m *Manager) AbortAccessRecord(r *AccessRecord, reason error) { + if r != nil && r.cancel != nil { + r.cancel() + } + m.completeAccessRecord(r, reason) +} + +// TrackAccessLink stores a deferred access record in ctx and wraps link so +// payload bytes can be accounted for at the body reader/writer boundary. +func (m *Manager) TrackAccessLink(ctx context.Context, msg *clog.AccessMessage, link *transport.Link, cancel context.CancelFunc) (context.Context, *transport.Link, *AccessRecord) { + record := m.NewAccessRecord(msg, cancel) + if record == nil || link == nil { + return ctx, link, record + } + ctx = ContextWithAccessRecord(ctx, record) + link = WrapAccessLink(link, record) + return ctx, link, record +} + +// WrapAccessLink wraps link so payload bytes are attributed to record. +func WrapAccessLink(link *transport.Link, record *AccessRecord) *transport.Link { + if link == nil || record == nil { + return link + } + if link.Reader != nil { + link.Reader = &TrackedAccessReader{Reader: link.Reader, record: record} + } + if link.Writer != nil { + link.Writer = &TrackedAccessWriter{Writer: link.Writer, record: record} + } + return link +} + +// TrackedAccessReader counts request payload bytes read by an accepted request. +type TrackedAccessReader struct { + buf.Reader + record *AccessRecord +} + +func (r *TrackedAccessReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + mb, err := r.Reader.ReadMultiBuffer() + if n := int64(mb.Len()); n > 0 && r.record != nil { + r.record.addRequestBytes(n) + } + return mb, err +} + +func (r *TrackedAccessReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { + if reader, ok := r.Reader.(buf.TimeoutReader); ok { + mb, err := reader.ReadMultiBufferTimeout(timeout) + if n := int64(mb.Len()); n > 0 && r.record != nil { + r.record.addRequestBytes(n) + } + return mb, err + } + return r.ReadMultiBuffer() +} + +func (r *TrackedAccessReader) Interrupt() { + common.Interrupt(r.Reader) +} + +func (r *TrackedAccessReader) Close() error { + return common.Close(r.Reader) +} + +// TrackedAccessWriter counts response payload bytes written by an accepted +// request. +type TrackedAccessWriter struct { + buf.Writer + record *AccessRecord +} + +func (w *TrackedAccessWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + n := int64(mb.Len()) + err := w.Writer.WriteMultiBuffer(mb) + if err == nil && n > 0 && w.record != nil { + w.record.addResponseBytes(n) + } + return err +} + +func (w *TrackedAccessWriter) Close() error { + return common.Close(w.Writer) +} diff --git a/app/connectiontracker/command/command.go b/app/connectiontracker/command/command.go new file mode 100644 index 000000000000..69b4d85a6096 --- /dev/null +++ b/app/connectiontracker/command/command.go @@ -0,0 +1,105 @@ +package command + +import ( + "context" + "strings" + + grpc "google.golang.org/grpc" + + "github.com/xtls/xray-core/app/connectiontracker" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/core" +) + +type connTrackerServer struct { + UnimplementedConnTrackerServiceServer + manager *connectiontracker.Manager +} + +func (s *connTrackerServer) ListConnections(_ context.Context, _ *ListConnectionsRequest) (*ListConnectionsResponse, error) { + all := s.manager.ListAllConnections() + resp := &ListConnectionsResponse{ + Connections: make([]*ConnInfo, 0, len(all)), + } + for _, c := range all { + resp.Connections = append(resp.Connections, toProto(c)) + } + return resp, nil +} + +func (s *connTrackerServer) CloseConnection(_ context.Context, req *CloseConnectionRequest) (*CloseConnectionResponse, error) { + found := s.manager.CloseGlobalConn(req.Id) + return &CloseConnectionResponse{Found: found}, nil +} + +func (s *connTrackerServer) GetUserStats(_ context.Context, req *GetUserStatsRequest) (*GetUserStatsResponse, error) { + up, down, count := s.manager.GetUserStats(strings.ToLower(req.Email)) + return &GetUserStatsResponse{ + Uplink: up, + Downlink: down, + ConnCount: count, + }, nil +} + +func (s *connTrackerServer) StreamConnections(_ *StreamConnectionsRequest, stream grpc.ServerStreamingServer[ConnectionUpdate]) error { + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + ctx := stream.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case ev, ok := <-ch: + if !ok { + return nil + } + evType := ConnEventType_DISCONNECTED + if ev.Connected { + evType = ConnEventType_CONNECTED + } + if err := stream.Send(&ConnectionUpdate{ + Event: evType, + Conn: toProto(ev.Info), + }); err != nil { + return err + } + } + } +} + +func toProto(c connectiontracker.ConnectionInfo) *ConnInfo { + return &ConnInfo{ + Id: c.ID, + Email: c.Email, + InboundTag: c.InboundTag, + Protocol: c.Protocol, + StartTime: c.StartTime.Unix(), + LastActivity: c.LastActivity.Unix(), + Uplink: c.Uplink, + Downlink: c.Downlink, + } +} + +type service struct { + manager *connectiontracker.Manager +} + +func (s *service) Register(server *grpc.Server) { + RegisterConnTrackerServiceServer(server, &connTrackerServer{manager: s.manager}) +} + +func init() { + common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, _ interface{}) (interface{}, error) { + s := new(service) + + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + s.manager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + + return s, nil + })) +} diff --git a/app/connectiontracker/command/command.pb.go b/app/connectiontracker/command/command.pb.go new file mode 100644 index 000000000000..94b379eed891 --- /dev/null +++ b/app/connectiontracker/command/command.pb.go @@ -0,0 +1,677 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: app/connectiontracker/command/command.proto + +package command + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ConnEventType int32 + +const ( + ConnEventType_CONNECTED ConnEventType = 0 + ConnEventType_DISCONNECTED ConnEventType = 1 +) + +var ( + ConnEventType_name = map[int32]string{ + 0: "CONNECTED", + 1: "DISCONNECTED", + } + ConnEventType_value = map[string]int32{ + "CONNECTED": 0, + "DISCONNECTED": 1, + } +) + +func (x ConnEventType) Enum() *ConnEventType { + p := new(ConnEventType) + *p = x + return p +} + +func (x ConnEventType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ConnEventType) Descriptor() protoreflect.EnumDescriptor { + return file_app_connectiontracker_command_command_proto_enumTypes[0].Descriptor() +} + +func (ConnEventType) Type() protoreflect.EnumType { + return &file_app_connectiontracker_command_command_proto_enumTypes[0] +} + +func (x ConnEventType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +func (x *ConnEventType) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = ConnEventType(num) + return nil +} + +func (ConnEventType) EnumDescriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{0} +} + +type ConnInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Email string `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"` + InboundTag string `protobuf:"bytes,3,opt,name=inbound_tag,json=inboundTag,proto3" json:"inbound_tag,omitempty"` + Protocol string `protobuf:"bytes,4,opt,name=protocol,proto3" json:"protocol,omitempty"` + StartTime int64 `protobuf:"varint,5,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` + LastActivity int64 `protobuf:"varint,6,opt,name=last_activity,json=lastActivity,proto3" json:"last_activity,omitempty"` + Uplink int64 `protobuf:"varint,7,opt,name=uplink,proto3" json:"uplink,omitempty"` + Downlink int64 `protobuf:"varint,8,opt,name=downlink,proto3" json:"downlink,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ConnInfo) Reset() { + *x = ConnInfo{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ConnInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConnInfo) ProtoMessage() {} + +func (x *ConnInfo) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*ConnInfo) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{0} +} + +func (x *ConnInfo) GetId() uint32 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *ConnInfo) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + +func (x *ConnInfo) GetInboundTag() string { + if x != nil { + return x.InboundTag + } + return "" +} + +func (x *ConnInfo) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +func (x *ConnInfo) GetStartTime() int64 { + if x != nil { + return x.StartTime + } + return 0 +} + +func (x *ConnInfo) GetLastActivity() int64 { + if x != nil { + return x.LastActivity + } + return 0 +} + +func (x *ConnInfo) GetUplink() int64 { + if x != nil { + return x.Uplink + } + return 0 +} + +func (x *ConnInfo) GetDownlink() int64 { + if x != nil { + return x.Downlink + } + return 0 +} + +type ListConnectionsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListConnectionsRequest) Reset() { + *x = ListConnectionsRequest{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListConnectionsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListConnectionsRequest) ProtoMessage() {} + +func (x *ListConnectionsRequest) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*ListConnectionsRequest) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{1} +} + +type ListConnectionsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Connections []*ConnInfo `protobuf:"bytes,1,rep,name=connections,proto3" json:"connections,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListConnectionsResponse) Reset() { + *x = ListConnectionsResponse{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListConnectionsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListConnectionsResponse) ProtoMessage() {} + +func (x *ListConnectionsResponse) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*ListConnectionsResponse) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{2} +} + +func (x *ListConnectionsResponse) GetConnections() []*ConnInfo { + if x != nil { + return x.Connections + } + return nil +} + +type CloseConnectionRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CloseConnectionRequest) Reset() { + *x = CloseConnectionRequest{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CloseConnectionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CloseConnectionRequest) ProtoMessage() {} + +func (x *CloseConnectionRequest) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*CloseConnectionRequest) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{3} +} + +func (x *CloseConnectionRequest) GetId() uint32 { + if x != nil { + return x.Id + } + return 0 +} + +type CloseConnectionResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Found bool `protobuf:"varint,1,opt,name=found,proto3" json:"found,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CloseConnectionResponse) Reset() { + *x = CloseConnectionResponse{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CloseConnectionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CloseConnectionResponse) ProtoMessage() {} + +func (x *CloseConnectionResponse) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*CloseConnectionResponse) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{4} +} + +func (x *CloseConnectionResponse) GetFound() bool { + if x != nil { + return x.Found + } + return false +} + +type GetUserStatsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Email string `protobuf:"bytes,1,opt,name=email,proto3" json:"email,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetUserStatsRequest) Reset() { + *x = GetUserStatsRequest{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetUserStatsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetUserStatsRequest) ProtoMessage() {} + +func (x *GetUserStatsRequest) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*GetUserStatsRequest) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{5} +} + +func (x *GetUserStatsRequest) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + +type GetUserStatsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Uplink int64 `protobuf:"varint,1,opt,name=uplink,proto3" json:"uplink,omitempty"` + Downlink int64 `protobuf:"varint,2,opt,name=downlink,proto3" json:"downlink,omitempty"` + ConnCount int32 `protobuf:"varint,3,opt,name=conn_count,json=connCount,proto3" json:"conn_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetUserStatsResponse) Reset() { + *x = GetUserStatsResponse{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetUserStatsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetUserStatsResponse) ProtoMessage() {} + +func (x *GetUserStatsResponse) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*GetUserStatsResponse) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{6} +} + +func (x *GetUserStatsResponse) GetUplink() int64 { + if x != nil { + return x.Uplink + } + return 0 +} + +func (x *GetUserStatsResponse) GetDownlink() int64 { + if x != nil { + return x.Downlink + } + return 0 +} + +func (x *GetUserStatsResponse) GetConnCount() int32 { + if x != nil { + return x.ConnCount + } + return 0 +} + +type StreamConnectionsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamConnectionsRequest) Reset() { + *x = StreamConnectionsRequest{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamConnectionsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamConnectionsRequest) ProtoMessage() {} + +func (x *StreamConnectionsRequest) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*StreamConnectionsRequest) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{7} +} + +type ConnectionUpdate struct { + state protoimpl.MessageState `protogen:"open.v1"` + Event ConnEventType `protobuf:"varint,1,opt,name=event,proto3,enum=xray.app.connectiontracker.command.ConnEventType" json:"event,omitempty"` + Conn *ConnInfo `protobuf:"bytes,2,opt,name=conn,proto3" json:"conn,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ConnectionUpdate) Reset() { + *x = ConnectionUpdate{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ConnectionUpdate) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConnectionUpdate) ProtoMessage() {} + +func (x *ConnectionUpdate) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*ConnectionUpdate) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{8} +} + +func (x *ConnectionUpdate) GetEvent() ConnEventType { + if x != nil { + return x.Event + } + return ConnEventType_CONNECTED +} + +func (x *ConnectionUpdate) GetConn() *ConnInfo { + if x != nil { + return x.Conn + } + return nil +} + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_app_connectiontracker_command_command_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_app_connectiontracker_command_command_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (*Config) Descriptor() ([]byte, []int) { + return file_app_connectiontracker_command_command_proto_rawDescGZIP(), []int{9} +} + +var File_app_connectiontracker_command_command_proto protoreflect.FileDescriptor + +// rawDesc is the binary proto FileDescriptorProto for this file. +// Verified manually against the proto source; do not edit by hand. +const file_app_connectiontracker_command_command_proto_rawDesc = "" + + "\n" + + "+app/connectiontracker/command/command.proto\x12\"xray.app.connectiontracker.command*0\n" + + "\rConnEventType\x12\r\n" + + "\tCONNECTED\x10\x00\x12\x10\n" + + "\fDISCONNECTED\x10\x01\"\xe5\x01\n" + + "\bConnInfo\x12\x0e\n" + + "\x02id\x18\x01 \x01(\rR\x02id\x12\x14\n" + + "\x05email\x18\x02 \x01(\tR\x05email\x12\x1f\n" + + "\vinbound_tag\x18\x03 \x01(\tR\n" + + "inboundTag\x12\x1a\n" + + "\bprotocol\x18\x04 \x01(\tR\bprotocol\x12\x1d\n" + + "\n" + + "start_time\x18\x05 \x01(\x03R\tstartTime\x12#\n" + + "\rlast_activity\x18\x06 \x01(\x03R\x0clastActivity\x12\x16\n" + + "\x06uplink\x18\a \x01(\x03R\x06uplink\x12\x1a\n" + + "\bdownlink\x18\b \x01(\x03R\bdownlink\"\x18\n" + + "\x16ListConnectionsRequest\"i\n" + + "\x17ListConnectionsResponse\x12N\n" + + "\vconnections\x18\x01 \x03(\v2,.xray.app.connectiontracker.command.ConnInfoR\vconnections\"(\n" + + "\x16CloseConnectionRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\rR\x02id\"/\n" + + "\x17CloseConnectionResponse\x12\x14\n" + + "\x05found\x18\x01 \x01(\bR\x05found\"+\n" + + "\x13GetUserStatsRequest\x12\x14\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"i\n" + + "\x14GetUserStatsResponse\x12\x16\n" + + "\x06uplink\x18\x01 \x01(\x03R\x06uplink\x12\x1a\n" + + "\bdownlink\x18\x02 \x01(\x03R\bdownlink\x12\x1d\n" + + "\n" + + "conn_count\x18\x03 \x01(\x05R\tconnCount\"\x1a\n" + + "\x18StreamConnectionsRequest\"\x9d\x01\n" + + "\x10ConnectionUpdate\x12G\n" + + "\x05event\x18\x01 \x01(\x0e21.xray.app.connectiontracker.command.ConnEventTypeR\x05event\x12@\n" + + "\x04conn\x18\x02 \x01(\v2,.xray.app.connectiontracker.command.ConnInfoR\x04conn\"\b\n" + + "\x06Config2\xc6\x04\n" + + "\x12ConnTrackerService\x12\x8c\x01\n" + + "\x0fListConnections\x12:.xray.app.connectiontracker.command.ListConnectionsRequest\x1a;.xray.app.connectiontracker.command.ListConnectionsResponse\"\x00\x12\x8c\x01\n" + + "\x0fCloseConnection\x12:.xray.app.connectiontracker.command.CloseConnectionRequest\x1a;.xray.app.connectiontracker.command.CloseConnectionResponse\"\x00\x12\x83\x01\n" + + "\fGetUserStats\x127.xray.app.connectiontracker.command.GetUserStatsRequest\x1a8.xray.app.connectiontracker.command.GetUserStatsResponse\"\x00\x12\x8b\x01\n" + + "\x11StreamConnections\x12<.xray.app.connectiontracker.command.StreamConnectionsRequest\x1a4.xray.app.connectiontracker.command.ConnectionUpdate\"\x000\x01" + + "B9Z7github.com/xtls/xray-core/app/connectiontracker/commandb\x06proto3" + +var ( + file_app_connectiontracker_command_command_proto_rawDescOnce sync.Once + file_app_connectiontracker_command_command_proto_rawDescData []byte +) + +func file_app_connectiontracker_command_command_proto_rawDescGZIP() []byte { + file_app_connectiontracker_command_command_proto_rawDescOnce.Do(func() { + file_app_connectiontracker_command_command_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_app_connectiontracker_command_command_proto_rawDesc), len(file_app_connectiontracker_command_command_proto_rawDesc))) + }) + return file_app_connectiontracker_command_command_proto_rawDescData +} + +var file_app_connectiontracker_command_command_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_app_connectiontracker_command_command_proto_msgTypes = make([]protoimpl.MessageInfo, 10) +var file_app_connectiontracker_command_command_proto_goTypes = []any{ + (ConnEventType)(0), // 0: xray.app.connectiontracker.command.ConnEventType + (*ConnInfo)(nil), // 1: xray.app.connectiontracker.command.ConnInfo + (*ListConnectionsRequest)(nil), // 2: xray.app.connectiontracker.command.ListConnectionsRequest + (*ListConnectionsResponse)(nil), // 3: xray.app.connectiontracker.command.ListConnectionsResponse + (*CloseConnectionRequest)(nil), // 4: xray.app.connectiontracker.command.CloseConnectionRequest + (*CloseConnectionResponse)(nil), // 5: xray.app.connectiontracker.command.CloseConnectionResponse + (*GetUserStatsRequest)(nil), // 6: xray.app.connectiontracker.command.GetUserStatsRequest + (*GetUserStatsResponse)(nil), // 7: xray.app.connectiontracker.command.GetUserStatsResponse + (*StreamConnectionsRequest)(nil), // 8: xray.app.connectiontracker.command.StreamConnectionsRequest + (*ConnectionUpdate)(nil), // 9: xray.app.connectiontracker.command.ConnectionUpdate + (*Config)(nil), // 10: xray.app.connectiontracker.command.Config +} +var file_app_connectiontracker_command_command_proto_depIdxs = []int32{ + 1, // 0: xray.app.connectiontracker.command.ListConnectionsResponse.connections:type_name -> xray.app.connectiontracker.command.ConnInfo + 0, // 1: xray.app.connectiontracker.command.ConnectionUpdate.event:type_name -> xray.app.connectiontracker.command.ConnEventType + 1, // 2: xray.app.connectiontracker.command.ConnectionUpdate.conn:type_name -> xray.app.connectiontracker.command.ConnInfo + 2, // 3: xray.app.connectiontracker.command.ConnTrackerService.ListConnections:input_type -> xray.app.connectiontracker.command.ListConnectionsRequest + 4, // 4: xray.app.connectiontracker.command.ConnTrackerService.CloseConnection:input_type -> xray.app.connectiontracker.command.CloseConnectionRequest + 6, // 5: xray.app.connectiontracker.command.ConnTrackerService.GetUserStats:input_type -> xray.app.connectiontracker.command.GetUserStatsRequest + 8, // 6: xray.app.connectiontracker.command.ConnTrackerService.StreamConnections:input_type -> xray.app.connectiontracker.command.StreamConnectionsRequest + 3, // 7: xray.app.connectiontracker.command.ConnTrackerService.ListConnections:output_type -> xray.app.connectiontracker.command.ListConnectionsResponse + 5, // 8: xray.app.connectiontracker.command.ConnTrackerService.CloseConnection:output_type -> xray.app.connectiontracker.command.CloseConnectionResponse + 7, // 9: xray.app.connectiontracker.command.ConnTrackerService.GetUserStats:output_type -> xray.app.connectiontracker.command.GetUserStatsResponse + 9, // 10: xray.app.connectiontracker.command.ConnTrackerService.StreamConnections:output_type -> xray.app.connectiontracker.command.ConnectionUpdate + 7, // [7:11] is the sub-list for method output_type + 3, // [3:7] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_app_connectiontracker_command_command_proto_init() } +func file_app_connectiontracker_command_command_proto_init() { + if File_app_connectiontracker_command_command_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_app_connectiontracker_command_command_proto_rawDesc), len(file_app_connectiontracker_command_command_proto_rawDesc)), + NumEnums: 1, + NumMessages: 10, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_app_connectiontracker_command_command_proto_goTypes, + DependencyIndexes: file_app_connectiontracker_command_command_proto_depIdxs, + EnumInfos: file_app_connectiontracker_command_command_proto_enumTypes, + MessageInfos: file_app_connectiontracker_command_command_proto_msgTypes, + }.Build() + File_app_connectiontracker_command_command_proto = out.File + file_app_connectiontracker_command_command_proto_goTypes = nil + file_app_connectiontracker_command_command_proto_depIdxs = nil +} diff --git a/app/connectiontracker/command/command.proto b/app/connectiontracker/command/command.proto new file mode 100644 index 000000000000..5855da2b40e9 --- /dev/null +++ b/app/connectiontracker/command/command.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package xray.app.connectiontracker.command; + +option go_package = "github.com/xtls/xray-core/app/connectiontracker/command"; + +enum ConnEventType { + CONNECTED = 0; + DISCONNECTED = 1; +} + +message ConnInfo { + uint32 id = 1; + string email = 2; + string inbound_tag = 3; + string protocol = 4; + int64 start_time = 5; + int64 last_activity = 6; + int64 uplink = 7; + int64 downlink = 8; +} + +message ListConnectionsRequest {} + +message ListConnectionsResponse { + repeated ConnInfo connections = 1; +} + +message CloseConnectionRequest { + uint32 id = 1; +} + +message CloseConnectionResponse { + bool found = 1; +} + +message GetUserStatsRequest { + string email = 1; +} + +message GetUserStatsResponse { + int64 uplink = 1; + int64 downlink = 2; + int32 conn_count = 3; +} + +message StreamConnectionsRequest {} + +message ConnectionUpdate { + ConnEventType event = 1; + ConnInfo conn = 2; +} + +// Config is the configuration for the connection-tracker gRPC command service. +// Add it to the commander's services list to enable the API. +message Config {} + +service ConnTrackerService { + rpc ListConnections(ListConnectionsRequest) returns (ListConnectionsResponse) {} + rpc CloseConnection(CloseConnectionRequest) returns (CloseConnectionResponse) {} + rpc GetUserStats(GetUserStatsRequest) returns (GetUserStatsResponse) {} + rpc StreamConnections(StreamConnectionsRequest) returns (stream ConnectionUpdate) {} +} diff --git a/app/connectiontracker/command/command_grpc.pb.go b/app/connectiontracker/command/command_grpc.pb.go new file mode 100644 index 000000000000..45317a368e58 --- /dev/null +++ b/app/connectiontracker/command/command_grpc.pb.go @@ -0,0 +1,224 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.0 +// - protoc v6.33.5 +// source: app/connectiontracker/command/command.proto + +package command + +import ( + context "context" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +const _ = grpc.SupportPackageIsVersion9 + +const ( + ConnTrackerService_ListConnections_FullMethodName = "/xray.app.connectiontracker.command.ConnTrackerService/ListConnections" + ConnTrackerService_CloseConnection_FullMethodName = "/xray.app.connectiontracker.command.ConnTrackerService/CloseConnection" + ConnTrackerService_GetUserStats_FullMethodName = "/xray.app.connectiontracker.command.ConnTrackerService/GetUserStats" + ConnTrackerService_StreamConnections_FullMethodName = "/xray.app.connectiontracker.command.ConnTrackerService/StreamConnections" +) + +// ConnTrackerServiceClient is the client API for ConnTrackerService service. +type ConnTrackerServiceClient interface { + ListConnections(ctx context.Context, in *ListConnectionsRequest, opts ...grpc.CallOption) (*ListConnectionsResponse, error) + CloseConnection(ctx context.Context, in *CloseConnectionRequest, opts ...grpc.CallOption) (*CloseConnectionResponse, error) + GetUserStats(ctx context.Context, in *GetUserStatsRequest, opts ...grpc.CallOption) (*GetUserStatsResponse, error) + StreamConnections(ctx context.Context, in *StreamConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionUpdate], error) +} + +type connTrackerServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewConnTrackerServiceClient(cc grpc.ClientConnInterface) ConnTrackerServiceClient { + return &connTrackerServiceClient{cc} +} + +func (c *connTrackerServiceClient) ListConnections(ctx context.Context, in *ListConnectionsRequest, opts ...grpc.CallOption) (*ListConnectionsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListConnectionsResponse) + err := c.cc.Invoke(ctx, ConnTrackerService_ListConnections_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *connTrackerServiceClient) CloseConnection(ctx context.Context, in *CloseConnectionRequest, opts ...grpc.CallOption) (*CloseConnectionResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(CloseConnectionResponse) + err := c.cc.Invoke(ctx, ConnTrackerService_CloseConnection_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *connTrackerServiceClient) GetUserStats(ctx context.Context, in *GetUserStatsRequest, opts ...grpc.CallOption) (*GetUserStatsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetUserStatsResponse) + err := c.cc.Invoke(ctx, ConnTrackerService_GetUserStats_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *connTrackerServiceClient) StreamConnections(ctx context.Context, in *StreamConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionUpdate], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ConnTrackerService_ServiceDesc.Streams[0], ConnTrackerService_StreamConnections_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[StreamConnectionsRequest, ConnectionUpdate]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ConnTrackerService_StreamConnectionsClient = grpc.ServerStreamingClient[ConnectionUpdate] + +// ConnTrackerServiceServer is the server API for ConnTrackerService service. +// All implementations must embed UnimplementedConnTrackerServiceServer +// for forward compatibility. +type ConnTrackerServiceServer interface { + ListConnections(context.Context, *ListConnectionsRequest) (*ListConnectionsResponse, error) + CloseConnection(context.Context, *CloseConnectionRequest) (*CloseConnectionResponse, error) + GetUserStats(context.Context, *GetUserStatsRequest) (*GetUserStatsResponse, error) + StreamConnections(*StreamConnectionsRequest, grpc.ServerStreamingServer[ConnectionUpdate]) error + mustEmbedUnimplementedConnTrackerServiceServer() +} + +// UnimplementedConnTrackerServiceServer must be embedded to have +// forward compatible implementations. +type UnimplementedConnTrackerServiceServer struct{} + +func (UnimplementedConnTrackerServiceServer) ListConnections(context.Context, *ListConnectionsRequest) (*ListConnectionsResponse, error) { + return nil, status.Error(codes.Unimplemented, "method ListConnections not implemented") +} +func (UnimplementedConnTrackerServiceServer) CloseConnection(context.Context, *CloseConnectionRequest) (*CloseConnectionResponse, error) { + return nil, status.Error(codes.Unimplemented, "method CloseConnection not implemented") +} +func (UnimplementedConnTrackerServiceServer) GetUserStats(context.Context, *GetUserStatsRequest) (*GetUserStatsResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetUserStats not implemented") +} +func (UnimplementedConnTrackerServiceServer) StreamConnections(*StreamConnectionsRequest, grpc.ServerStreamingServer[ConnectionUpdate]) error { + return status.Error(codes.Unimplemented, "method StreamConnections not implemented") +} +func (UnimplementedConnTrackerServiceServer) mustEmbedUnimplementedConnTrackerServiceServer() {} +func (UnimplementedConnTrackerServiceServer) testEmbeddedByValue() {} + +// UnsafeConnTrackerServiceServer may be embedded to opt out of forward compatibility for this service. +type UnsafeConnTrackerServiceServer interface { + mustEmbedUnimplementedConnTrackerServiceServer() +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ConnTrackerService_StreamConnectionsServer = grpc.ServerStreamingServer[ConnectionUpdate] + +func RegisterConnTrackerServiceServer(s grpc.ServiceRegistrar, srv ConnTrackerServiceServer) { + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ConnTrackerService_ServiceDesc, srv) +} + +func _ConnTrackerService_ListConnections_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListConnectionsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConnTrackerServiceServer).ListConnections(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConnTrackerService_ListConnections_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConnTrackerServiceServer).ListConnections(ctx, req.(*ListConnectionsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConnTrackerService_CloseConnection_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CloseConnectionRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConnTrackerServiceServer).CloseConnection(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConnTrackerService_CloseConnection_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConnTrackerServiceServer).CloseConnection(ctx, req.(*CloseConnectionRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConnTrackerService_GetUserStats_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetUserStatsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConnTrackerServiceServer).GetUserStats(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConnTrackerService_GetUserStats_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConnTrackerServiceServer).GetUserStats(ctx, req.(*GetUserStatsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConnTrackerService_StreamConnections_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StreamConnectionsRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(ConnTrackerServiceServer).StreamConnections(m, &grpc.GenericServerStream[StreamConnectionsRequest, ConnectionUpdate]{ServerStream: stream}) +} + +// ConnTrackerService_ServiceDesc is the grpc.ServiceDesc for ConnTrackerService service. +var ConnTrackerService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "xray.app.connectiontracker.command.ConnTrackerService", + HandlerType: (*ConnTrackerServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ListConnections", + Handler: _ConnTrackerService_ListConnections_Handler, + }, + { + MethodName: "CloseConnection", + Handler: _ConnTrackerService_CloseConnection_Handler, + }, + { + MethodName: "GetUserStats", + Handler: _ConnTrackerService_GetUserStats_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamConnections", + Handler: _ConnTrackerService_StreamConnections_Handler, + ServerStreams: true, + }, + }, + Metadata: "app/connectiontracker/command/command.proto", +} diff --git a/app/connectiontracker/service.go b/app/connectiontracker/service.go new file mode 100644 index 000000000000..478dfd54fa82 --- /dev/null +++ b/app/connectiontracker/service.go @@ -0,0 +1,58 @@ +package connectiontracker + +import ( + "context" + + "github.com/xtls/xray-core/common" + xrayfeatures "github.com/xtls/xray-core/features" +) + +// Feature exposes the shared connection tracker manager through Xray's +// feature-resolution system. +type Feature interface { + xrayfeatures.Feature + Manager() *Manager +} + +func FeatureType() interface{} { + return (*Feature)(nil) +} + +// Config is the config placeholder for explicitly constructing the tracker +// feature. +type Config struct{} + +// Service owns the singleton manager for one Xray instance. +type Service struct { + manager *Manager +} + +func NewService() *Service { + return &Service{ + manager: NewManager(), + } +} + +func (*Service) Type() interface{} { + return FeatureType() +} + +func (s *Service) Start() error { + return nil +} + +func (s *Service) Close() error { + return s.manager.Close() +} + +func (s *Service) Manager() *Manager { + return s.manager +} + +var _ Feature = (*Service)(nil) + +func init() { + common.Must(common.RegisterConfig((*Config)(nil), func(context.Context, interface{}) (interface{}, error) { + return NewService(), nil + })) +} diff --git a/app/connectiontracker/tracker.go b/app/connectiontracker/tracker.go new file mode 100644 index 000000000000..e66535a08933 --- /dev/null +++ b/app/connectiontracker/tracker.go @@ -0,0 +1,450 @@ +// Package connectiontracker provides a thread-safe registry of active proxy +// connections. It enables forced per-user disconnection and exposes real-time +// per-connection metadata and traffic statistics for API consumers. +package connectiontracker + +import ( + "context" + "io" + "sync" + "sync/atomic" + "time" + + B "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/xtls/xray-core/transport/internet/stat" +) + +// ConnEntry holds metadata and traffic state for a single tracked connection. +type ConnEntry struct { + Email string + InboundTag string + Protocol string + Cancel context.CancelFunc + StartTime time.Time + lastActivity int64 // atomic, Unix nanosecond timestamp + uplink int64 // atomic, bytes received from client + downlink int64 // atomic, bytes sent to client + closerMu sync.Mutex + closer io.Closer +} + +// ConnectionInfo is a read-only snapshot of an active connection's state. +type ConnectionInfo struct { + ID uint32 + Email string + InboundTag string + Protocol string + StartTime time.Time + LastActivity time.Time + Uplink int64 + Downlink int64 +} + +// Manager holds the shared connection registry and subscription fan-out for +// a single Xray instance. +type Manager struct { + globalNext uint32 + + globalMu sync.Mutex + trackers []*Tracker + + subMu sync.Mutex + subscribers []chan WatchEvent +} + +// Tracker tracks active connections per user, enabling forced disconnection +// and real-time connection inspection. +type Tracker struct { + manager *Manager + + mu sync.Mutex + conns map[string]map[uint32]*ConnEntry // [email][id] -> entry + byID map[uint32]*ConnEntry // flat index for O(1) lookup by ID +} + +// WatchEvent is delivered to subscribers whenever a connection opens or closes. +type WatchEvent struct { + Connected bool // true = opened, false = closed + Info ConnectionInfo +} + +// NewManager creates an empty tracker manager. +func NewManager() *Manager { + return &Manager{} +} + +// Close clears manager-owned registries. +func (m *Manager) Close() error { + m.globalMu.Lock() + m.trackers = nil + m.globalMu.Unlock() + + m.subMu.Lock() + m.subscribers = nil + m.subMu.Unlock() + + return nil +} + +func (m *Manager) snapshotTrackers() []*Tracker { + m.globalMu.Lock() + trackers := make([]*Tracker, len(m.trackers)) + copy(trackers, m.trackers) + m.globalMu.Unlock() + return trackers +} + +// Subscribe returns a channel that receives WatchEvents. Call Unsubscribe when +// done to avoid a goroutine / channel leak. +func (m *Manager) Subscribe() chan WatchEvent { + ch := make(chan WatchEvent, 64) + m.subMu.Lock() + m.subscribers = append(m.subscribers, ch) + m.subMu.Unlock() + return ch +} + +// Unsubscribe removes a channel returned by Subscribe. +func (m *Manager) Unsubscribe(ch chan WatchEvent) { + m.subMu.Lock() + defer m.subMu.Unlock() + + for i, s := range m.subscribers { + if s == ch { + m.subscribers[i] = m.subscribers[len(m.subscribers)-1] + m.subscribers = m.subscribers[:len(m.subscribers)-1] + return + } + } +} + +func (m *Manager) emit(ev WatchEvent) { + m.subMu.Lock() + subs := make([]chan WatchEvent, len(m.subscribers)) + copy(subs, m.subscribers) + m.subMu.Unlock() + for _, ch := range subs { + select { + case ch <- ev: + default: // drop if subscriber is too slow + } + } +} + +// NewTracker creates a new, empty Tracker and registers it in the manager so +// that ListAllConnections and CloseGlobalConn can see its connections. +func (m *Manager) NewTracker() *Tracker { + t := &Tracker{ + manager: m, + conns: make(map[string]map[uint32]*ConnEntry), + byID: make(map[uint32]*ConnEntry), + } + m.globalMu.Lock() + m.trackers = append(m.trackers, t) + m.globalMu.Unlock() + return t +} + +// New creates a new Tracker with its own Manager. +func New() *Tracker { + return NewManager().NewTracker() +} + +func disconnectInfo(id uint32, entry *ConnEntry) ConnectionInfo { + return ConnectionInfo{ + ID: id, + Email: entry.Email, + InboundTag: entry.InboundTag, + Protocol: entry.Protocol, + StartTime: entry.StartTime, + LastActivity: time.Unix(0, atomic.LoadInt64(&entry.lastActivity)), + Uplink: atomic.LoadInt64(&entry.uplink), + Downlink: atomic.LoadInt64(&entry.downlink), + } +} + +// ListAllConnections returns a snapshot of every active connection across all +// Tracker instances that were created by NewTracker. +func (m *Manager) ListAllConnections() []ConnectionInfo { + ts := m.snapshotTrackers() + var all []ConnectionInfo + for _, t := range ts { + all = append(all, t.ListConnections()...) + } + return all +} + +// GetUserStats returns the aggregate uplink bytes, downlink bytes, and active +// connection count for email across all registered Trackers. +func (m *Manager) GetUserStats(email string) (uplink, downlink int64, connCount int32) { + ts := m.snapshotTrackers() + for _, t := range ts { + t.mu.Lock() + for _, e := range t.conns[email] { + uplink += atomic.LoadInt64(&e.uplink) + downlink += atomic.LoadInt64(&e.downlink) + connCount++ + } + t.mu.Unlock() + } + return +} + +// CloseGlobalConn closes the connection with the given ID in whichever Tracker +// owns it. Returns true if the connection was found and cancelled. +func (m *Manager) CloseGlobalConn(id uint32) bool { + ts := m.snapshotTrackers() + for _, t := range ts { + if t.CloseConn(id) { + return true + } + } + return false +} + +// Register records a connection's cancel function under email and returns its +// ID. Use RegisterWithMeta for richer per-connection tracking. +func (t *Tracker) Register(email string, cancel context.CancelFunc) uint32 { + id, _ := t.RegisterWithMeta(email, cancel, "", "") + return id +} + +// RegisterWithMeta records a connection with full metadata and returns the +// connection ID and a *ConnEntry whose traffic counters can be updated by +// passing it to WrapConn. +func (t *Tracker) RegisterWithMeta(email string, cancel context.CancelFunc, inboundTag, protocol string) (uint32, *ConnEntry) { + now := time.Now() + entry := &ConnEntry{ + Email: email, + InboundTag: inboundTag, + Protocol: protocol, + Cancel: cancel, + StartTime: now, + } + atomic.StoreInt64(&entry.lastActivity, now.UnixNano()) + id := atomic.AddUint32(&t.manager.globalNext, 1) + t.mu.Lock() + if t.conns[email] == nil { + t.conns[email] = make(map[uint32]*ConnEntry) + } + t.conns[email][id] = entry + t.byID[id] = entry + t.mu.Unlock() + t.manager.emit(WatchEvent{Connected: true, Info: ConnectionInfo{ + ID: id, + Email: email, + InboundTag: inboundTag, + Protocol: protocol, + StartTime: now, + LastActivity: now, + }}) + return id, entry +} + +// Unregister removes a connection from the tracker when it closes naturally. +func (t *Tracker) Unregister(email string, id uint32) { + t.mu.Lock() + entry := t.byID[id] + delete(t.byID, id) + if m := t.conns[email]; m != nil { + delete(m, id) + if len(m) == 0 { + delete(t.conns, email) + } + } + t.mu.Unlock() + if entry != nil { + t.manager.emit(WatchEvent{Connected: false, Info: disconnectInfo(id, entry)}) + } +} + +// CancelAll cancels every active connection belonging to email. +func (t *Tracker) CancelAll(email string) { + t.mu.Lock() + entries := t.conns[email] + delete(t.conns, email) + for id := range entries { + delete(t.byID, id) + } + t.mu.Unlock() + + for id, entry := range entries { + t.manager.emit(WatchEvent{ + Connected: false, + Info: disconnectInfo(id, entry), + }) + entry.cancelAndClose() + } +} + +// CloseConn cancels the connection identified by id. +// Returns true if the connection was found and cancelled. +func (t *Tracker) CloseConn(id uint32) bool { + t.mu.Lock() + entry, ok := t.byID[id] + if ok { + delete(t.byID, id) + if m := t.conns[entry.Email]; m != nil { + delete(m, id) + if len(m) == 0 { + delete(t.conns, entry.Email) + } + } + } + t.mu.Unlock() + + if ok { + t.manager.emit(WatchEvent{ + Connected: false, + Info: disconnectInfo(id, entry), + }) + entry.cancelAndClose() + } + return ok +} + +// GetConnCount returns the number of active connections for email. +func (t *Tracker) GetConnCount(email string) int { + t.mu.Lock() + n := len(t.conns[email]) + t.mu.Unlock() + return n +} + +// ListConnections returns a snapshot of all currently active connections. +func (t *Tracker) ListConnections() []ConnectionInfo { + t.mu.Lock() + result := make([]ConnectionInfo, 0, len(t.byID)) + for id, entry := range t.byID { + info := disconnectInfo(id, entry) + info.Uplink = atomic.LoadInt64(&entry.uplink) + info.Downlink = atomic.LoadInt64(&entry.downlink) + result = append(result, info) + } + t.mu.Unlock() + return result +} + +// TrackedConn wraps a stat.Connection and records per-connection traffic +// counters into the associated ConnEntry. Obtain one via WrapConn. +type TrackedConn struct { + stat.Connection + entry *ConnEntry +} + +func (c *TrackedConn) updateActivity(uplink, downlink int64) { + now := time.Now().UnixNano() + if uplink > 0 { + atomic.AddInt64(&c.entry.uplink, uplink) + } + if downlink > 0 { + atomic.AddInt64(&c.entry.downlink, downlink) + } + if uplink > 0 || downlink > 0 { + atomic.StoreInt64(&c.entry.lastActivity, now) + } +} + +func (c *TrackedConn) Read(b []byte) (int, error) { + n, err := c.Connection.Read(b) + if n > 0 { + c.updateActivity(int64(n), 0) + } + return n, err +} + +func (c *TrackedConn) Write(b []byte) (int, error) { + n, err := c.Connection.Write(b) + if n > 0 { + c.updateActivity(0, int64(n)) + } + return n, err +} + +// WrapConn wraps conn so that every Read and Write updates the traffic +// counters in entry. Call after RegisterWithMeta to enable byte-level tracking. +func WrapConn(conn stat.Connection, entry *ConnEntry) stat.Connection { + if entry != nil { + entry.setCloser(conn) + } + return &TrackedConn{Connection: conn, entry: entry} +} + +// TrackedPacketConn wraps an N.PacketConn (UDP) and records per-connection +// traffic counters into the associated ConnEntry. +type TrackedPacketConn struct { + N.PacketConn + entry *ConnEntry +} + +func (c *TrackedPacketConn) updateActivity(uplink, downlink int64) { + now := time.Now().UnixNano() + if uplink > 0 { + atomic.AddInt64(&c.entry.uplink, uplink) + } + if downlink > 0 { + atomic.AddInt64(&c.entry.downlink, downlink) + } + atomic.StoreInt64(&c.entry.lastActivity, now) +} + +func (c *TrackedPacketConn) ReadPacket(buffer *B.Buffer) (M.Socksaddr, error) { + addr, err := c.PacketConn.ReadPacket(buffer) + if err == nil && buffer.Len() > 0 { + c.updateActivity(int64(buffer.Len()), 0) + } + return addr, err +} + +func (c *TrackedPacketConn) WritePacket(buffer *B.Buffer, destination M.Socksaddr) error { + n := buffer.Len() + err := c.PacketConn.WritePacket(buffer, destination) + if err == nil && n > 0 { + c.updateActivity(0, int64(n)) + } + return err +} + +// WrapPacketConn wraps a UDP PacketConn so that every ReadPacket and WritePacket +// updates the traffic counters in entry. Call after RegisterWithMeta to enable +// byte-level tracking for UDP connections. +func WrapPacketConn(conn N.PacketConn, entry *ConnEntry) N.PacketConn { + if entry != nil { + entry.setCloser(conn) + } + return &TrackedPacketConn{PacketConn: conn, entry: entry} +} + +func (e *ConnEntry) setCloser(c io.Closer) { + if e == nil || c == nil { + return + } + e.closerMu.Lock() + e.closer = c + e.closerMu.Unlock() +} + +func (e *ConnEntry) closeCloser() { + if e == nil { + return + } + e.closerMu.Lock() + closer := e.closer + e.closerMu.Unlock() + if closer == nil { + return + } + _ = closer.Close() +} + +func (e *ConnEntry) cancelAndClose() { + if e == nil { + return + } + if e.Cancel != nil { + e.Cancel() + } + e.closeCloser() +} diff --git a/app/connectiontracker/tracker_test.go b/app/connectiontracker/tracker_test.go new file mode 100644 index 000000000000..9a491662841c --- /dev/null +++ b/app/connectiontracker/tracker_test.go @@ -0,0 +1,604 @@ +package connectiontracker_test + +import ( + "net" + "sync" + "sync/atomic" + "testing" + "time" + + B "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + "github.com/xtls/xray-core/app/connectiontracker" +) + +func TestCancelAll(t *testing.T) { + tracker := connectiontracker.New() + + var cancelCount int32 + makeCancel := func() func() { + return func() { atomic.AddInt32(&cancelCount, 1) } + } + + tracker.Register("user@example.com", makeCancel()) + tracker.Register("user@example.com", makeCancel()) + tracker.Register("other@example.com", makeCancel()) + + tracker.CancelAll("user@example.com") + + if got := atomic.LoadInt32(&cancelCount); got != 2 { + t.Errorf("CancelAll: expected 2 cancels called, got %d", got) + } +} + +func TestCancelAllClosesTrackedConnections(t *testing.T) { + tracker := connectiontracker.New() + + _, firstEntry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + _, secondEntry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + var firstClosed int32 + var secondClosed int32 + + connectiontracker.WrapConn(&fakeConn{closeCount: &firstClosed}, firstEntry) + connectiontracker.WrapConn(&fakeConn{closeCount: &secondClosed}, secondEntry) + + tracker.CancelAll("user@example.com") + + if atomic.LoadInt32(&firstClosed) != 1 { + t.Error("first tracked connection was not closed during CancelAll") + } + if atomic.LoadInt32(&secondClosed) != 1 { + t.Error("second tracked connection was not closed during CancelAll") + } +} + +func TestCancelAllDoesNotAffectOtherUsers(t *testing.T) { + tracker := connectiontracker.New() + + var otherCancelled int32 + tracker.Register("other@example.com", func() { atomic.AddInt32(&otherCancelled, 1) }) + + tracker.CancelAll("user@example.com") + + if atomic.LoadInt32(&otherCancelled) != 0 { + t.Error("CancelAll for user@example.com must not cancel other users") + } +} + +func TestUnregisterPreventsCancel(t *testing.T) { + tracker := connectiontracker.New() + + var cancelCalled int32 + id := tracker.Register("user@example.com", func() { atomic.AddInt32(&cancelCalled, 1) }) + + tracker.Unregister("user@example.com", id) + tracker.CancelAll("user@example.com") + + if atomic.LoadInt32(&cancelCalled) != 0 { + t.Error("cancel should not be called after Unregister") + } +} + +func TestUnregisterCleansEmptyBucket(t *testing.T) { + tracker := connectiontracker.New() + + id := tracker.Register("user@example.com", func() {}) + tracker.Unregister("user@example.com", id) + + // Second CancelAll must be a no-op. + tracker.CancelAll("user@example.com") +} + +func TestMultipleCancelAllNoPanic(t *testing.T) { + tracker := connectiontracker.New() + + tracker.Register("user@example.com", func() {}) + tracker.CancelAll("user@example.com") + tracker.CancelAll("user@example.com") +} + +func TestConcurrentAccess(t *testing.T) { + tracker := connectiontracker.New() + + const goroutines = 50 + const email = "concurrent@example.com" + + var wg sync.WaitGroup + var totalCancels int32 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + tracker.Register(email, func() { atomic.AddInt32(&totalCancels, 1) }) + }() + } + wg.Wait() + + tracker.CancelAll(email) + + if got := atomic.LoadInt32(&totalCancels); got != goroutines { + t.Errorf("concurrent: expected %d cancels, got %d", goroutines, got) + } +} + +func TestConcurrentRegisterAndCancel(t *testing.T) { + tracker := connectiontracker.New() + + const email = "race@example.com" + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(2) + go func() { + defer wg.Done() + tracker.Register(email, func() {}) + }() + go func() { + defer wg.Done() + tracker.CancelAll(email) + }() + } + wg.Wait() +} + +// --- RegisterWithMeta and extended API --- + +func TestRegisterWithMetaStoresMetadata(t *testing.T) { + tracker := connectiontracker.New() + + before := time.Now() + id, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "inbound-tag", "vless") + after := time.Now() + + if id == 0 { + t.Error("expected non-zero connection ID") + } + if entry == nil { + t.Fatal("expected non-nil ConnEntry") + } + if entry.Email != "user@example.com" { + t.Errorf("Email: got %q, want %q", entry.Email, "user@example.com") + } + if entry.InboundTag != "inbound-tag" { + t.Errorf("InboundTag: got %q, want %q", entry.InboundTag, "inbound-tag") + } + if entry.Protocol != "vless" { + t.Errorf("Protocol: got %q, want %q", entry.Protocol, "vless") + } + if entry.StartTime.Before(before) || entry.StartTime.After(after) { + t.Errorf("StartTime %v outside [%v, %v]", entry.StartTime, before, after) + } +} + +func TestListConnectionsReturnsAllActive(t *testing.T) { + tracker := connectiontracker.New() + + tracker.RegisterWithMeta("alice@example.com", func() {}, "tag-a", "vmess") + tracker.RegisterWithMeta("alice@example.com", func() {}, "tag-a", "vmess") + tracker.RegisterWithMeta("bob@example.com", func() {}, "tag-b", "trojan") + + conns := tracker.ListConnections() + if len(conns) != 3 { + t.Errorf("ListConnections: got %d, want 3", len(conns)) + } +} + +func TestManagerListAllConnectionsAggregatesTrackers(t *testing.T) { + manager := connectiontracker.NewManager() + first := manager.NewTracker() + second := manager.NewTracker() + + first.RegisterWithMeta("alice@example.com", func() {}, "tag-a", "vmess") + second.RegisterWithMeta("bob@example.com", func() {}, "tag-b", "trojan") + + conns := manager.ListAllConnections() + if len(conns) != 2 { + t.Fatalf("ListAllConnections: got %d, want 2", len(conns)) + } +} + +func TestManagerGetUserStatsAggregatesTrackers(t *testing.T) { + manager := connectiontracker.NewManager() + first := manager.NewTracker() + second := manager.NewTracker() + + _, firstEntry := first.RegisterWithMeta("user@example.com", func() {}, "", "") + firstConn := connectiontracker.WrapConn(&fakeConn{readData: make([]byte, 10)}, firstEntry) + if _, err := firstConn.Read(make([]byte, 10)); err != nil { + t.Fatalf("first read failed: %v", err) + } + if _, err := firstConn.Write(make([]byte, 20)); err != nil { + t.Fatalf("first write failed: %v", err) + } + + _, secondEntry := second.RegisterWithMeta("user@example.com", func() {}, "", "") + secondConn := connectiontracker.WrapConn(&fakeConn{readData: make([]byte, 30)}, secondEntry) + if _, err := secondConn.Read(make([]byte, 30)); err != nil { + t.Fatalf("second read failed: %v", err) + } + if _, err := secondConn.Write(make([]byte, 40)); err != nil { + t.Fatalf("second write failed: %v", err) + } + + uplink, downlink, connCount := manager.GetUserStats("user@example.com") + if uplink != 40 { + t.Fatalf("GetUserStats uplink: got %d, want 40", uplink) + } + if downlink != 60 { + t.Fatalf("GetUserStats downlink: got %d, want 60", downlink) + } + if connCount != 2 { + t.Fatalf("GetUserStats connCount: got %d, want 2", connCount) + } +} + +func TestManagerCloseGlobalConnAcrossTrackers(t *testing.T) { + manager := connectiontracker.NewManager() + first := manager.NewTracker() + second := manager.NewTracker() + + first.RegisterWithMeta("other@example.com", func() {}, "", "") + + var cancelled int32 + id, _ := second.RegisterWithMeta("user@example.com", func() { + atomic.AddInt32(&cancelled, 1) + }, "", "") + + if ok := manager.CloseGlobalConn(id); !ok { + t.Fatal("CloseGlobalConn: expected true for existing connection") + } + if atomic.LoadInt32(&cancelled) != 1 { + t.Fatal("CloseGlobalConn: cancel function was not called") + } +} + +func TestListConnectionsEmptyAfterCancelAll(t *testing.T) { + tracker := connectiontracker.New() + + tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.CancelAll("user@example.com") + + if conns := tracker.ListConnections(); len(conns) != 0 { + t.Errorf("expected 0 connections after CancelAll, got %d", len(conns)) + } +} + +func TestListConnectionsEmptyAfterUnregister(t *testing.T) { + tracker := connectiontracker.New() + + id, _ := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.Unregister("user@example.com", id) + + if conns := tracker.ListConnections(); len(conns) != 0 { + t.Errorf("expected 0 connections after Unregister, got %d", len(conns)) + } +} + +func TestCloseConnCancelsAndRemoves(t *testing.T) { + tracker := connectiontracker.New() + + var cancelled int32 + id, _ := tracker.RegisterWithMeta("user@example.com", func() { + atomic.AddInt32(&cancelled, 1) + }, "", "") + + if ok := tracker.CloseConn(id); !ok { + t.Error("CloseConn: expected true for existing connection") + } + if atomic.LoadInt32(&cancelled) != 1 { + t.Error("CloseConn: cancel function was not called") + } + if len(tracker.ListConnections()) != 0 { + t.Error("connection still present after CloseConn") + } +} + +func TestCloseConnClosesTrackedConnection(t *testing.T) { + tracker := connectiontracker.New() + + id, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + var closeCount int32 + connectiontracker.WrapConn(&fakeConn{closeCount: &closeCount}, entry) + + if ok := tracker.CloseConn(id); !ok { + t.Error("CloseConn: expected true for existing connection") + } + if atomic.LoadInt32(&closeCount) != 1 { + t.Error("CloseConn: tracked connection was not closed") + } +} + +func TestCloseConnUnknownIDReturnsFalse(t *testing.T) { + tracker := connectiontracker.New() + + if tracker.CloseConn(999) { + t.Error("CloseConn with unknown ID should return false") + } +} + +func TestCloseConnDoesNotAffectOtherUsers(t *testing.T) { + tracker := connectiontracker.New() + + var otherCancelled int32 + tracker.RegisterWithMeta("other@example.com", func() { + atomic.AddInt32(&otherCancelled, 1) + }, "", "") + + id, _ := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.CloseConn(id) + + if atomic.LoadInt32(&otherCancelled) != 0 { + t.Error("CloseConn must not cancel other users' connections") + } +} + +func TestGetConnCount(t *testing.T) { + tracker := connectiontracker.New() + + tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.RegisterWithMeta("other@example.com", func() {}, "", "") + + if n := tracker.GetConnCount("user@example.com"); n != 2 { + t.Errorf("GetConnCount: got %d, want 2", n) + } + if n := tracker.GetConnCount("other@example.com"); n != 1 { + t.Errorf("GetConnCount: got %d, want 1", n) + } + if n := tracker.GetConnCount("nobody@example.com"); n != 0 { + t.Errorf("GetConnCount for unknown: got %d, want 0", n) + } +} + +func TestGetConnCountDecreasesAfterUnregister(t *testing.T) { + tracker := connectiontracker.New() + + id, _ := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + tracker.Unregister("user@example.com", id) + + if n := tracker.GetConnCount("user@example.com"); n != 1 { + t.Errorf("GetConnCount after Unregister: got %d, want 1", n) + } +} + +func TestListConnectionsMetadataFields(t *testing.T) { + tracker := connectiontracker.New() + + tracker.RegisterWithMeta("user@example.com", func() {}, "my-tag", "trojan") + + conns := tracker.ListConnections() + if len(conns) != 1 { + t.Fatalf("expected 1 connection, got %d", len(conns)) + } + c := conns[0] + if c.Email != "user@example.com" { + t.Errorf("Email: %q", c.Email) + } + if c.InboundTag != "my-tag" { + t.Errorf("InboundTag: %q", c.InboundTag) + } + if c.Protocol != "trojan" { + t.Errorf("Protocol: %q", c.Protocol) + } + if c.ID == 0 { + t.Error("ID must be non-zero") + } +} + +// fakeConn is a minimal net.Conn for WrapConn tests. +type fakeConn struct { + net.Conn + readData []byte + readErr error + writeErr error + closeErr error + closeCount *int32 +} + +func (f *fakeConn) Read(b []byte) (int, error) { + n := copy(b, f.readData) + return n, f.readErr +} + +func (f *fakeConn) Write(b []byte) (int, error) { + return len(b), f.writeErr +} + +func (f *fakeConn) Close() error { + if f.closeCount != nil { + atomic.AddInt32(f.closeCount, 1) + } + return f.closeErr +} +func (f *fakeConn) LocalAddr() net.Addr { return nil } +func (f *fakeConn) RemoteAddr() net.Addr { return nil } +func (f *fakeConn) SetDeadline(_ time.Time) error { return nil } +func (f *fakeConn) SetReadDeadline(_ time.Time) error { return nil } +func (f *fakeConn) SetWriteDeadline(_ time.Time) error { return nil } + +func TestWrapConnCountsUplinkOnRead(t *testing.T) { + tracker := connectiontracker.New() + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + fc := &fakeConn{readData: []byte("hello world")} + wrapped := connectiontracker.WrapConn(fc, entry) + + buf := make([]byte, 11) + if _, err := wrapped.Read(buf); err != nil { + t.Fatal(err) + } + + conns := tracker.ListConnections() + if len(conns) != 1 { + t.Fatalf("expected 1 connection") + } + if conns[0].Uplink != 11 { + t.Errorf("Uplink: got %d, want 11", conns[0].Uplink) + } + if conns[0].Downlink != 0 { + t.Errorf("Downlink should be 0, got %d", conns[0].Downlink) + } +} + +func TestWrapConnCountsDownlinkOnWrite(t *testing.T) { + tracker := connectiontracker.New() + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + fc := &fakeConn{} + wrapped := connectiontracker.WrapConn(fc, entry) + + data := []byte("goodbye world") + if _, err := wrapped.Write(data); err != nil { + t.Fatal(err) + } + + conns := tracker.ListConnections() + if len(conns) != 1 { + t.Fatalf("expected 1 connection") + } + if conns[0].Downlink != int64(len(data)) { + t.Errorf("Downlink: got %d, want %d", conns[0].Downlink, len(data)) + } + if conns[0].Uplink != 0 { + t.Errorf("Uplink should be 0, got %d", conns[0].Uplink) + } +} + +func TestWrapConnUpdatesLastActivity(t *testing.T) { + tracker := connectiontracker.New() + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + before := tracker.ListConnections()[0].LastActivity + + time.Sleep(time.Millisecond) + + fc := &fakeConn{readData: []byte("x")} + wrapped := connectiontracker.WrapConn(fc, entry) + buf := make([]byte, 1) + wrapped.Read(buf) //nolint:errcheck + + after := tracker.ListConnections()[0].LastActivity + if !after.After(before) { + t.Errorf("LastActivity not updated: before=%v after=%v", before, after) + } +} + +// fakePacketConn is a minimal N.PacketConn for WrapPacketConn tests. +type fakePacketConn struct { + readPacketData *B.Buffer + readPacketErr error + writePacketErr error +} + +func (f *fakePacketConn) ReadPacket(buffer *B.Buffer) (M.Socksaddr, error) { + if f.readPacketErr != nil { + return M.Socksaddr{}, f.readPacketErr + } + if f.readPacketData != nil { + buffer.Write(f.readPacketData.Bytes()) + } + return M.Socksaddr{}, nil +} + +func (f *fakePacketConn) WritePacket(buffer *B.Buffer, _ M.Socksaddr) error { + return f.writePacketErr +} + +func (f *fakePacketConn) Close() error { + return nil +} + +func (f *fakePacketConn) LocalAddr() net.Addr { + return nil +} + +func (f *fakePacketConn) SetDeadline(_ time.Time) error { + return nil +} + +func (f *fakePacketConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (f *fakePacketConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +func TestWrapPacketConnCountsUplinkOnReadPacket(t *testing.T) { + tracker := connectiontracker.New() + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + data := B.New() + data.Write([]byte("hello world")) + + fpc := &fakePacketConn{readPacketData: data} + wrapped := connectiontracker.WrapPacketConn(fpc, entry) + + buf := B.New() + defer buf.Release() + if _, err := wrapped.ReadPacket(buf); err != nil { + t.Fatal(err) + } + + conns := tracker.ListConnections() + if len(conns) != 1 { + t.Fatalf("expected 1 connection") + } + if conns[0].Uplink != 11 { + t.Errorf("Uplink: got %d, want 11", conns[0].Uplink) + } + if conns[0].Downlink != 0 { + t.Errorf("Downlink should be 0, got %d", conns[0].Downlink) + } +} + +func TestWrapPacketConnCountsDownlinkOnWritePacket(t *testing.T) { + tracker := connectiontracker.New() + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + fpc := &fakePacketConn{} + wrapped := connectiontracker.WrapPacketConn(fpc, entry) + + buf := B.New() + buf.Write([]byte("goodbye world")) + if err := wrapped.WritePacket(buf, M.Socksaddr{}); err != nil { + t.Fatal(err) + } + + conns := tracker.ListConnections() + if len(conns) != 1 { + t.Fatalf("expected 1 connection") + } + if conns[0].Downlink != 13 { + t.Errorf("Downlink: got %d, want 13", conns[0].Downlink) + } + if conns[0].Uplink != 0 { + t.Errorf("Uplink should be 0, got %d", conns[0].Uplink) + } +} + +func TestWrapPacketConnUpdatesLastActivity(t *testing.T) { + tracker := connectiontracker.New() + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") + + before := tracker.ListConnections()[0].LastActivity + + time.Sleep(time.Millisecond) + + data := B.New() + data.Write([]byte("x")) + + fpc := &fakePacketConn{readPacketData: data} + wrapped := connectiontracker.WrapPacketConn(fpc, entry) + buf := B.New() + defer buf.Release() + wrapped.ReadPacket(buf) //nolint:errcheck + + after := tracker.ListConnections()[0].LastActivity + if !after.After(before) { + t.Errorf("LastActivity not updated: before=%v after=%v", before, after) + } +} diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index f6cfd76ebf6f..e22cc10a67da 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -498,7 +499,9 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. accessMessage.Detour = inTag + " >> " + tag } } - log.Record(accessMessage) + if connectiontracker.AccessRecordFromContext(ctx) == nil { + log.Record(accessMessage) + } } handler.Dispatch(ctx, link) diff --git a/common/log/access.go b/common/log/access.go index 204212dc175d..e952b3906614 100644 --- a/common/log/access.go +++ b/common/log/access.go @@ -2,6 +2,7 @@ package log import ( "context" + "strconv" "strings" "github.com/xtls/xray-core/common/serial" @@ -27,6 +28,9 @@ type AccessMessage struct { Reason interface{} Email string Detour string + + RequestBytes int64 + ResponseBytes int64 } func (m *AccessMessage) String() string { @@ -55,6 +59,13 @@ func (m *AccessMessage) String() string { builder.WriteString(m.Email) } + if m.RequestBytes > 0 || m.ResponseBytes > 0 { + builder.WriteString(" request_bytes: ") + builder.WriteString(strconv.FormatInt(m.RequestBytes, 10)) + builder.WriteString(" response_bytes: ") + builder.WriteString(strconv.FormatInt(m.ResponseBytes, 10)) + } + return builder.String() } diff --git a/core/xray.go b/core/xray.go index 58135c96b800..b66026f933b6 100644 --- a/core/xray.go +++ b/core/xray.go @@ -5,6 +5,7 @@ import ( "reflect" "sync" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/platform" @@ -211,6 +212,7 @@ func initInstanceWithConfig(config *Config, server *Instance) (bool, error) { Instance features.Feature }{ {dns.ClientType(), localdns.New()}, + {connectiontracker.FeatureType(), connectiontracker.NewService()}, {policy.ManagerType(), policy.DefaultManager{}}, {routing.RouterType(), routing.DefaultRouter{}}, {stats.ManagerType(), stats.NoopManager{}}, diff --git a/infra/conf/api.go b/infra/conf/api.go index dca34910b4ba..0a7370ea29d7 100644 --- a/infra/conf/api.go +++ b/infra/conf/api.go @@ -4,6 +4,7 @@ import ( "strings" "github.com/xtls/xray-core/app/commander" + connservice "github.com/xtls/xray-core/app/connectiontracker/command" loggerservice "github.com/xtls/xray-core/app/log/command" observatoryservice "github.com/xtls/xray-core/app/observatory/command" handlerservice "github.com/xtls/xray-core/app/proxyman/command" @@ -39,6 +40,8 @@ func (c *APIConfig) Build() (*commander.Config, error) { services = append(services, serial.ToTypedMessage(&observatoryservice.Config{})) case "routingservice": services = append(services, serial.ToTypedMessage(&routerservice.Config{})) + case "connectiontrackerservice": + services = append(services, serial.ToTypedMessage(&connservice.Config{})) } } diff --git a/main/commands/all/api/api.go b/main/commands/all/api/api.go index 85d7aa9f46c7..0bb86ec80b66 100644 --- a/main/commands/all/api/api.go +++ b/main/commands/all/api/api.go @@ -34,5 +34,9 @@ var CmdAPI = &base.Command{ cmdOnlineStats, cmdOnlineStatsIpList, cmdGetAllOnlineUsers, + cmdConnList, + cmdConnClose, + cmdConnUserStats, + cmdConnStream, }, } diff --git a/main/commands/all/api/conn_close.go b/main/commands/all/api/conn_close.go new file mode 100644 index 000000000000..9902e090d7bb --- /dev/null +++ b/main/commands/all/api/conn_close.go @@ -0,0 +1,62 @@ +package api + +import ( + "fmt" + + connService "github.com/xtls/xray-core/app/connectiontracker/command" + "github.com/xtls/xray-core/main/commands/base" +) + +var cmdConnClose = &base.Command{ + CustomFlags: true, + UsageLine: "{{.Exec}} api connclose [--server=127.0.0.1:8080] -id ", + Short: "Close an active connection by ID", + Long: ` +Force-close an active proxy connection by its numeric ID. +Use 'connlist' to find connection IDs. + +Arguments: + + -s, -server + The API server address. Default 127.0.0.1:8080 + + -t, -timeout + Timeout in seconds for calling API. Default 3 + + -id + The connection ID to close (required). + +Example: + + {{.Exec}} {{.LongName}} --server=127.0.0.1:8080 -id 42 +`, + Run: executeConnClose, +} + +func executeConnClose(cmd *base.Command, args []string) { + setSharedFlags(cmd) + id := cmd.Flag.Uint("id", 0, "") + cmd.Flag.Parse(args) + + if *id == 0 { + base.Fatalf("connection id is required and must be non-zero") + } + + conn, ctx, close := dialAPIServer() + defer close() + + client := connService.NewConnTrackerServiceClient(conn) + resp, err := client.CloseConnection(ctx, &connService.CloseConnectionRequest{ + Id: uint32(*id), + }) + if err != nil { + base.Fatalf("failed to close connection: %s", err) + } + if apiJSON { + showJSONResponse(resp) + } else if resp.Found { + fmt.Printf("connection %d closed\n", *id) + } else { + fmt.Printf("connection %d not found\n", *id) + } +} diff --git a/main/commands/all/api/conn_list.go b/main/commands/all/api/conn_list.go new file mode 100644 index 000000000000..84174015b7ae --- /dev/null +++ b/main/commands/all/api/conn_list.go @@ -0,0 +1,46 @@ +package api + +import ( + connService "github.com/xtls/xray-core/app/connectiontracker/command" + "github.com/xtls/xray-core/main/commands/base" +) + +var cmdConnList = &base.Command{ + CustomFlags: true, + UsageLine: "{{.Exec}} api connlist [--server=127.0.0.1:8080]", + Short: "List all active connections", + Long: ` +List all active proxy connections tracked by Xray. + +Arguments: + + -s, -server + The API server address. Default 127.0.0.1:8080 + + -t, -timeout + Timeout in seconds for calling API. Default 3 + + -json + Output as JSON. + +Example: + + {{.Exec}} {{.LongName}} --server=127.0.0.1:8080 +`, + Run: executeConnList, +} + +func executeConnList(cmd *base.Command, args []string) { + setSharedFlags(cmd) + cmd.Flag.Parse(args) + + conn, ctx, close := dialAPIServer() + defer close() + + client := connService.NewConnTrackerServiceClient(conn) + resp, err := client.ListConnections(ctx, &connService.ListConnectionsRequest{}) + if err != nil { + base.Fatalf("failed to list connections: %s", err) + } + showJSONResponse(resp) +} diff --git a/main/commands/all/api/conn_stream.go b/main/commands/all/api/conn_stream.go new file mode 100644 index 000000000000..81684baa4ff3 --- /dev/null +++ b/main/commands/all/api/conn_stream.go @@ -0,0 +1,77 @@ +package api + +import ( + "fmt" + "io" + "os" + + connService "github.com/xtls/xray-core/app/connectiontracker/command" + creflect "github.com/xtls/xray-core/common/reflect" + "github.com/xtls/xray-core/main/commands/base" +) + +var cmdConnStream = &base.Command{ + CustomFlags: true, + UsageLine: "{{.Exec}} api connstream [--server=127.0.0.1:8080]", + Short: "Stream live connection open/close events", + Long: ` +Subscribe to a live stream of connection lifecycle events from Xray. +Each event is printed as it occurs. Press Ctrl+C to stop. + +Arguments: + + -s, -server + The API server address. Default 127.0.0.1:8080 + + -t, -timeout + Timeout in seconds for calling API. Default 3 (use a larger value + or 0 for indefinite streaming). + + -json + Output each event as JSON. Default: human-readable. + +Example: + + {{.Exec}} {{.LongName}} --server=127.0.0.1:8080 -timeout 0 +`, + Run: executeConnStream, +} + +func executeConnStream(cmd *base.Command, args []string) { + setSharedFlags(cmd) + cmd.Flag.Parse(args) + + conn, ctx, close := dialAPIServer() + defer close() + + client := connService.NewConnTrackerServiceClient(conn) + stream, err := client.StreamConnections(ctx, &connService.StreamConnectionsRequest{}) + if err != nil { + base.Fatalf("failed to start stream: %s", err) + } + + for { + update, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + base.Fatalf("stream error: %s", err) + } + if apiJSON { + if j, ok := creflect.MarshalToJson(update, true); ok { + fmt.Println(j) + } else { + fmt.Fprintf(os.Stderr, "failed to encode event as JSON\n") + } + } else { + action := "CONNECTED" + if update.Event == connService.ConnEventType_DISCONNECTED { + action = "DISCONNECTED" + } + c := update.Conn + fmt.Printf("[%s] id=%d email=%s protocol=%s inbound=%s\n", + action, c.GetId(), c.GetEmail(), c.GetProtocol(), c.GetInboundTag()) + } + } +} diff --git a/main/commands/all/api/conn_user_stats.go b/main/commands/all/api/conn_user_stats.go new file mode 100644 index 000000000000..2d5155b27161 --- /dev/null +++ b/main/commands/all/api/conn_user_stats.go @@ -0,0 +1,57 @@ +package api + +import ( + connService "github.com/xtls/xray-core/app/connectiontracker/command" + "github.com/xtls/xray-core/main/commands/base" +) + +var cmdConnUserStats = &base.Command{ + CustomFlags: true, + UsageLine: "{{.Exec}} api connuserstats [--server=127.0.0.1:8080] -email ", + Short: "Get traffic stats and connection count for a user", + Long: ` +Retrieve the aggregated uplink bytes, downlink bytes, and active connection +count for a given user across all inbounds. + +Arguments: + + -s, -server + The API server address. Default 127.0.0.1:8080 + + -t, -timeout + Timeout in seconds for calling API. Default 3 + + -email + The user's email address (required). + + -json + Output as JSON. + +Example: + + {{.Exec}} {{.LongName}} --server=127.0.0.1:8080 -email "user@example.com" +`, + Run: executeConnUserStats, +} + +func executeConnUserStats(cmd *base.Command, args []string) { + setSharedFlags(cmd) + email := cmd.Flag.String("email", "", "") + cmd.Flag.Parse(args) + + if *email == "" { + base.Fatalf("email is required") + } + + conn, ctx, close := dialAPIServer() + defer close() + + client := connService.NewConnTrackerServiceClient(conn) + resp, err := client.GetUserStats(ctx, &connService.GetUserStatsRequest{ + Email: *email, + }) + if err != nil { + base.Fatalf("failed to get user stats: %s", err) + } + showJSONResponse(resp) +} diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 39c790c7ebd9..cec08c6c2350 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -24,8 +25,8 @@ import ( func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { d := new(DokodemoDoor) - err := core.RequireFeatures(ctx, func(pm policy.Manager) error { - return d.Init(config.(*Config), pm, session.SockoptFromContext(ctx)) + err := core.RequireFeatures(ctx, func(pm policy.Manager, trackerSvc connectiontracker.Feature) error { + return d.Init(config.(*Config), pm, session.SockoptFromContext(ctx), trackerSvc.Manager()) }) return d, err })) @@ -38,10 +39,11 @@ type DokodemoDoor struct { rewritePort net.Port portMap map[string]string sockopt *session.Sockopt + accessManager *connectiontracker.Manager } // Init initializes the DokodemoDoor instance with necessary parameters. -func (d *DokodemoDoor) Init(config *Config, pm policy.Manager, sockopt *session.Sockopt) error { +func (d *DokodemoDoor) Init(config *Config, pm policy.Manager, sockopt *session.Sockopt, accessManager *connectiontracker.Manager) error { if len(config.AllowedNetworks) == 0 { return errors.New("no network specified") } @@ -51,6 +53,7 @@ func (d *DokodemoDoor) Init(config *Config, pm policy.Manager, sockopt *session. d.portMap = config.PortMap d.policyManager = pm d.sockopt = sockopt + d.accessManager = accessManager return nil } @@ -190,12 +193,19 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st } } - if err := dispatcher.DispatchLink( - ctx, dest, &transport.Link{ - Reader: reader, - Writer: writer, - }, - ); err != nil { + link := &transport.Link{ + Reader: reader, + Writer: writer, + } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && d.accessManager != nil { + ctx, link, accessRecord = d.accessManager.TrackAccessLink(ctx, accessMessage, link, nil) + defer d.accessManager.FinishAccessRecord(accessRecord) + } + if err := dispatcher.DispatchLink(ctx, dest, link); err != nil { + if accessRecord != nil { + d.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to dispatch request").Base(err) } return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process() diff --git a/proxy/http/server.go b/proxy/http/server.go index 054e40bf444b..caf90b08e28a 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -31,14 +32,23 @@ import ( type Server struct { config *ServerConfig policyManager policy.Manager + accessManager *connectiontracker.Manager } // NewServer creates a new HTTP inbound handler. func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } s := &Server{ config: config, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + accessManager: trackerManager, } return s, nil @@ -192,12 +202,19 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *buf if inbound.CanSpliceCopy == 2 { inbound.CanSpliceCopy = 1 } - if err := dispatcher.DispatchLink( - ctx, dest, &transport.Link{ - Reader: reader, - Writer: buf.NewWriter(conn), - }, - ); err != nil { + link := &transport.Link{ + Reader: reader, + Writer: buf.NewWriter(conn), + } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + ctx, link, accessRecord = s.accessManager.TrackAccessLink(ctx, accessMessage, link, nil) + defer s.accessManager.FinishAccessRecord(accessRecord) + } + if err := dispatcher.DispatchLink(ctx, dest, link); err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to dispatch request").Base(err) } return nil @@ -247,10 +264,21 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri ctx = session.ContextWithContent(ctx, content) + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + accessRecord = s.accessManager.NewAccessRecord(accessMessage, nil) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer s.accessManager.FinishAccessRecord(accessRecord) + } + link, err := dispatcher.Dispatch(ctx, dest) if err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return err } + link = connectiontracker.WrapAccessLink(link, accessRecord) // Plain HTTP request is not a stream. The request always finishes before response. Hense request has to be closed later. defer common.Close(link.Writer) @@ -307,6 +335,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri if err := task.Run(ctx, requestDone, responseDone); err != nil { common.Interrupt(link.Reader) common.Interrupt(link.Writer) + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("connection ends").Base(err) } diff --git a/proxy/hysteria/server.go b/proxy/hysteria/server.go index 815faca11574..529f7532aadc 100644 --- a/proxy/hysteria/server.go +++ b/proxy/hysteria/server.go @@ -2,8 +2,10 @@ package hysteria import ( "context" + "strings" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -24,6 +26,8 @@ type Server struct { config *ServerConfig validator *account.Validator policyManager policy.Manager + accessManager *connectiontracker.Manager + connTracker *connectiontracker.Tracker } func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { @@ -40,10 +44,23 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { } v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + s := &Server{ config: config, validator: validator, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + accessManager: trackerManager, + connTracker: trackerManager.NewTracker(), } return s, nil @@ -58,6 +75,7 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { } func (s *Server) RemoveUser(ctx context.Context, e string) error { + s.connTracker.CancelAll(strings.ToLower(e)) return s.validator.Del(e) } @@ -81,13 +99,26 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound := session.InboundFromContext(ctx) inbound.Name = "hysteria" inbound.CanSpliceCopy = 3 - inbound.User = &protocol.MemoryUser{} iConn := stat.TryUnwrapStatsConn(conn) + var useremail string + var userlevel uint32 type User interface{ User() *protocol.MemoryUser } - if v, ok := iConn.(User); ok && v.User() != nil { + if v, ok := iConn.(User); ok { inbound.User = v.User() + if inbound.User != nil { + useremail = inbound.User.Email + userlevel = inbound.User.Level + } + } + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(inbound.User.Email); email != "" { + connID, connEntry := s.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "hysteria") + defer s.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapConn(conn, connEntry) } if _, ok := iConn.(*hysteria.InterConn); ok { @@ -118,7 +149,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con Writer: writer, }) } else { - sessionPolicy := s.policyManager.ForLevel(inbound.User.Level) + sessionPolicy := s.policyManager.ForLevel(userlevel) common.Must(conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake))) addr, err := ReadTCPRequest(conn) @@ -142,7 +173,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con To: dest, Status: log.AccessAccepted, Reason: "", - Email: inbound.User.Email, + Email: useremail, }) errors.LogInfo(ctx, "tunnelling request to ", dest) @@ -155,10 +186,22 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con return err } - return dispatcher.DispatchLink(ctx, dest, &transport.Link{ + link := &transport.Link{ Reader: buf.NewReader(conn), Writer: bufferedWriter, - }) + } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + ctx, link, accessRecord = s.accessManager.TrackAccessLink(ctx, accessMessage, link, connCancel) + defer s.accessManager.FinishAccessRecord(accessRecord) + } + if err := dispatcher.DispatchLink(ctx, dest, link); err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 360ea38c8d53..d60e6af23c77 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -2,8 +2,10 @@ package shadowsocks import ( "context" + "strings" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -25,7 +27,9 @@ type Server struct { config *ServerConfig validator *Validator policyManager policy.Manager + accessManager *connectiontracker.Manager cone bool + connTracker *connectiontracker.Tracker } // NewServer create a new Shadowsocks server. @@ -43,11 +47,24 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { } v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + s := &Server{ config: config, validator: validator, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + accessManager: trackerManager, cone: ctx.Value("cone").(bool), + connTracker: trackerManager.NewTracker(), } return s, nil @@ -60,6 +77,7 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { // RemoveUser implements proxy.UserManager.RemoveUser(). func (s *Server) RemoveUser(ctx context.Context, e string) error { + s.connTracker.CancelAll(strings.ToLower(e)) return s.validator.Del(e) } @@ -235,12 +253,27 @@ func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dis sessionPolicy = s.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + if email := strings.ToLower(request.User.Email); email != "" { + connID, connEntry := s.connTracker.RegisterWithMeta(email, cancel, inbound.Tag, "shadowsocks") + defer s.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapConn(conn, connEntry) + } ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + accessRecord = s.accessManager.NewAccessRecord(accessMessage, cancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer s.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, dest) if err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return err } + link = connectiontracker.WrapAccessLink(link, accessRecord) responseDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) @@ -286,6 +319,9 @@ func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dis if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil { common.Interrupt(link.Reader) common.Interrupt(link.Writer) + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("connection ends").Base(err) } diff --git a/proxy/shadowsocks_2022/inbound.go b/proxy/shadowsocks_2022/inbound.go index edf9857c87fe..f9a2c7fc2f6a 100644 --- a/proxy/shadowsocks_2022/inbound.go +++ b/proxy/shadowsocks_2022/inbound.go @@ -2,6 +2,7 @@ package shadowsocks_2022 import ( "context" + "strings" "time" shadowsocks "github.com/sagernet/sing-shadowsocks" @@ -12,6 +13,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -21,6 +23,7 @@ import ( "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/singbridge" + "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -32,10 +35,12 @@ func init() { } type Inbound struct { - networks []net.Network - service shadowsocks.Service - email string - level int + networks []net.Network + service shadowsocks.Service + email string + level int + accessManager *connectiontracker.Manager + connTracker *connectiontracker.Tracker } func NewServer(ctx context.Context, config *ServerConfig) (*Inbound, error) { @@ -46,10 +51,24 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Inbound, error) { net.Network_UDP, } } + + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + inbound := &Inbound{ - networks: networks, - email: config.Email, - level: int(config.Level), + networks: networks, + email: config.Email, + level: int(config.Level), + accessManager: trackerManager, + connTracker: trackerManager.NewTracker(), } if !C.Contains(shadowaead_2022.List, config.Method) { return nil, errors.New("unsupported method ", config.Method) @@ -116,16 +135,42 @@ func (i *Inbound) NewConnection(ctx context.Context, conn net.Conn, metadata M.M Email: i.email, }) errors.LogInfo(ctx, "tunnelling request to tcp:", metadata.Destination) + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(i.email); email != "" { + connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022") + defer i.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapConn(conn, connEntry) + } + dispatcher := session.DispatcherFromContext(ctx) destination, err := singbridge.ToDestination(metadata.Destination, net.Network_TCP) if err != nil { return err } + + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && i.accessManager != nil { + accessRecord = i.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer i.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } return err } - return singbridge.CopyConn(ctx, nil, link, conn) + link = connectiontracker.WrapAccessLink(link, accessRecord) + if err := singbridge.CopyConn(ctx, nil, link, conn); err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } func (i *Inbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { @@ -141,15 +186,34 @@ func (i *Inbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, me Email: i.email, }) errors.LogInfo(ctx, "tunnelling request to udp:", metadata.Destination) + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(i.email); email != "" { + connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022") + defer i.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapPacketConn(conn, connEntry) + } + dispatcher := session.DispatcherFromContext(ctx) destination, err := singbridge.ToDestination(metadata.Destination, net.Network_UDP) if err != nil { return err } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && i.accessManager != nil { + accessRecord = i.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer i.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } return err } + link = connectiontracker.WrapAccessLink(link, accessRecord) outConn := &singbridge.PacketConnWrapper{ Reader: link.Reader, Writer: link.Writer, @@ -158,7 +222,13 @@ func (i *Inbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, me common.Interrupt(link.Reader) }, 300*time.Second), } - return bufio.CopyPacketConn(ctx, conn, outConn) + if err := bufio.CopyPacketConn(ctx, conn, outConn); err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } func (i *Inbound) NewError(ctx context.Context, err error) { diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index d6d68c09ae29..aa3bef1d3ed8 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -16,6 +16,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -26,6 +27,7 @@ import ( "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/singbridge" "github.com/xtls/xray-core/common/uuid" + "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -38,9 +40,11 @@ func init() { type MultiUserInbound struct { sync.Mutex - networks []net.Network - users []*protocol.MemoryUser - service *shadowaead_2022.MultiService[int] + networks []net.Network + users []*protocol.MemoryUser + service *shadowaead_2022.MultiService[int] + accessManager *connectiontracker.Manager + connTracker *connectiontracker.Tracker } func NewMultiServer(ctx context.Context, config *MultiUserServerConfig) (*MultiUserInbound, error) { @@ -64,9 +68,22 @@ func NewMultiServer(ctx context.Context, config *MultiUserServerConfig) (*MultiU memUsers = append(memUsers, u) } + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + inbound := &MultiUserInbound{ - networks: networks, - users: memUsers, + networks: networks, + users: memUsers, + accessManager: trackerManager, + connTracker: trackerManager.NewTracker(), } if config.Key == "" { return nil, errors.New("missing key") @@ -121,6 +138,8 @@ func (i *MultiUserInbound) RemoveUser(ctx context.Context, email string) error { return errors.New("Email must not be empty.") } + i.connTracker.CancelAll(strings.ToLower(email)) + i.Lock() defer i.Unlock() @@ -238,16 +257,41 @@ func (i *MultiUserInbound) NewConnection(ctx context.Context, conn net.Conn, met Email: user.Email, }) errors.LogInfo(ctx, "tunnelling request to tcp:", metadata.Destination) + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(user.Email); email != "" { + connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022") + defer i.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapConn(conn, connEntry) + } + dispatcher := session.DispatcherFromContext(ctx) destination, err := singbridge.ToDestination(metadata.Destination, net.Network_TCP) if err != nil { return err } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && i.accessManager != nil { + accessRecord = i.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer i.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } return err } - return singbridge.CopyConn(ctx, conn, link, conn) + link = connectiontracker.WrapAccessLink(link, accessRecord) + if err := singbridge.CopyConn(ctx, conn, link, conn); err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } func (i *MultiUserInbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { @@ -262,15 +306,34 @@ func (i *MultiUserInbound) NewPacketConnection(ctx context.Context, conn N.Packe Email: user.Email, }) errors.LogInfo(ctx, "tunnelling request to udp:", metadata.Destination) + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(user.Email); email != "" { + connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022") + defer i.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapPacketConn(conn, connEntry) + } + dispatcher := session.DispatcherFromContext(ctx) destination, err := singbridge.ToDestination(metadata.Destination, net.Network_UDP) if err != nil { return err } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && i.accessManager != nil { + accessRecord = i.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer i.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } return err } + link = connectiontracker.WrapAccessLink(link, accessRecord) outConn := &singbridge.PacketConnWrapper{ Reader: link.Reader, Writer: link.Writer, @@ -279,7 +342,13 @@ func (i *MultiUserInbound) NewPacketConnection(ctx context.Context, conn N.Packe common.Interrupt(link.Reader) }, 300*time.Second), } - return bufio.CopyPacketConn(ctx, conn, outConn) + if err := bufio.CopyPacketConn(ctx, conn, outConn); err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } func (i *MultiUserInbound) NewError(ctx context.Context, err error) { diff --git a/proxy/shadowsocks_2022/inbound_relay.go b/proxy/shadowsocks_2022/inbound_relay.go index 4ca5e20753c7..1521f3930634 100644 --- a/proxy/shadowsocks_2022/inbound_relay.go +++ b/proxy/shadowsocks_2022/inbound_relay.go @@ -14,6 +14,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -24,6 +25,7 @@ import ( "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/singbridge" "github.com/xtls/xray-core/common/uuid" + "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -35,9 +37,11 @@ func init() { } type RelayInbound struct { - networks []net.Network - destinations []*RelayDestination - service *shadowaead_2022.RelayService[int] + networks []net.Network + destinations []*RelayDestination + service *shadowaead_2022.RelayService[int] + accessManager *connectiontracker.Manager + connTracker *connectiontracker.Tracker } func NewRelayServer(ctx context.Context, config *RelayServerConfig) (*RelayInbound, error) { @@ -48,9 +52,23 @@ func NewRelayServer(ctx context.Context, config *RelayServerConfig) (*RelayInbou net.Network_UDP, } } + + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + inbound := &RelayInbound{ - networks: networks, - destinations: config.Destinations, + networks: networks, + destinations: config.Destinations, + accessManager: trackerManager, + connTracker: trackerManager.NewTracker(), } if !C.Contains(shadowaead_2022.List, config.Method) || !strings.Contains(config.Method, "aes") { return nil, errors.New("unsupported method ", config.Method) @@ -139,16 +157,42 @@ func (i *RelayInbound) NewConnection(ctx context.Context, conn net.Conn, metadat Email: user.Email, }) errors.LogInfo(ctx, "tunnelling request to tcp:", metadata.Destination) + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(user.Email); email != "" { + connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022-relay") + defer i.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapConn(conn, connEntry) + } + dispatcher := session.DispatcherFromContext(ctx) destination, err := singbridge.ToDestination(metadata.Destination, net.Network_TCP) if err != nil { return err } + + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && i.accessManager != nil { + accessRecord = i.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer i.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } return err } - return singbridge.CopyConn(ctx, nil, link, conn) + link = connectiontracker.WrapAccessLink(link, accessRecord) + if err := singbridge.CopyConn(ctx, nil, link, conn); err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } func (i *RelayInbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { @@ -166,15 +210,34 @@ func (i *RelayInbound) NewPacketConnection(ctx context.Context, conn N.PacketCon Email: user.Email, }) errors.LogInfo(ctx, "tunnelling request to udp:", metadata.Destination) + + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(user.Email); email != "" { + connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022-relay") + defer i.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapPacketConn(conn, connEntry) + } + dispatcher := session.DispatcherFromContext(ctx) destination, err := singbridge.ToDestination(metadata.Destination, net.Network_UDP) if err != nil { return err } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && i.accessManager != nil { + accessRecord = i.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer i.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } return err } + link = connectiontracker.WrapAccessLink(link, accessRecord) outConn := &singbridge.PacketConnWrapper{ Reader: link.Reader, Writer: link.Writer, @@ -183,7 +246,13 @@ func (i *RelayInbound) NewPacketConnection(ctx context.Context, conn N.PacketCon common.Interrupt(link.Reader) }, 300*time.Second), } - return bufio.CopyPacketConn(ctx, conn, outConn) + if err := bufio.CopyPacketConn(ctx, conn, outConn); err != nil { + if accessRecord != nil { + i.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } func (i *RelayInbound) NewError(ctx context.Context, err error) { diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 53049dfe2342..951cc8eac360 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -28,6 +29,7 @@ import ( type Server struct { config *ServerConfig policyManager policy.Manager + accessManager *connectiontracker.Manager cone bool httpServer *http.Server } @@ -35,9 +37,17 @@ type Server struct { // NewServer creates a new Server object. func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } s := &Server{ config: config, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + accessManager: trackerManager, cone: ctx.Value("cone").(bool), } httpConfig := &http.ServerConfig{ @@ -46,7 +56,11 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { if config.AuthType == AuthType_PASSWORD { httpConfig.Accounts = config.Accounts } - s.httpServer, _ = http.NewServer(ctx, httpConfig) + httpServer, err := http.NewServer(ctx, httpConfig) + if err != nil { + return nil, err + } + s.httpServer = httpServer return s, nil } @@ -153,12 +167,19 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche if inbound.CanSpliceCopy == 2 { inbound.CanSpliceCopy = 1 } - if err := dispatcher.DispatchLink( - ctx, dest, &transport.Link{ - Reader: reader, - Writer: buf.NewWriter(conn), - }, - ); err != nil { + link := &transport.Link{ + Reader: reader, + Writer: buf.NewWriter(conn), + } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + ctx, link, accessRecord = s.accessManager.TrackAccessLink(ctx, accessMessage, link, nil) + defer s.accessManager.FinishAccessRecord(accessRecord) + } + if err := dispatcher.DispatchLink(ctx, dest, link); err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to dispatch request").Base(err) } return nil diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index d5979e529365..1632381be3b0 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -38,7 +39,9 @@ type Server struct { policyManager policy.Manager validator *Validator fallbacks map[string]map[string]map[string]*Fallback // or nil + accessManager *connectiontracker.Manager cone bool + connTracker *connectiontracker.Tracker } // NewServer creates a new trojan inbound handler. @@ -56,10 +59,23 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { } v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + server := &Server{ policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), validator: validator, + accessManager: trackerManager, cone: ctx.Value("cone").(bool), + connTracker: trackerManager.NewTracker(), } if config.Fallbacks != nil { @@ -122,6 +138,7 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { // RemoveUser implements proxy.UserManager.RemoveUser(). func (s *Server) RemoveUser(ctx context.Context, e string) error { + s.connTracker.CancelAll(strings.ToLower(e)) return s.validator.Del(e) } @@ -228,6 +245,14 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound.User = user sessionPolicy = s.policyManager.ForLevel(user.Level) + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(user.Email); email != "" { + connID, connEntry := s.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "trojan") + defer s.connTracker.Unregister(email, connID) + conn = connectiontracker.WrapConn(conn, connEntry) + } + if destination.Network == net.Network_UDP { // handle udp request return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher) } @@ -328,11 +353,21 @@ func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Sess ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + accessRecord = s.accessManager.NewAccessRecord(accessMessage, cancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer s.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, destination) if err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to dispatch request to ", destination).Base(err) } + link = connectiontracker.WrapAccessLink(link, accessRecord) requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) @@ -355,6 +390,9 @@ func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Sess if err := task.Run(ctx, requestDonePost, responseDone); err != nil { common.Must(common.Interrupt(link.Reader)) common.Must(common.Interrupt(link.Writer)) + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("connection ends").Base(err) } diff --git a/proxy/tun/handler.go b/proxy/tun/handler.go index d552eca537f7..99389834e3c5 100644 --- a/proxy/tun/handler.go +++ b/proxy/tun/handler.go @@ -6,6 +6,7 @@ import ( "strings" "syscall" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" c "github.com/xtls/xray-core/common/ctx" @@ -30,6 +31,7 @@ type Handler struct { tun Tun policyManager policy.Manager dispatcher routing.Dispatcher + accessManager *connectiontracker.Manager tag string sniffingRequest session.SniffingRequest } @@ -46,7 +48,7 @@ var _ ConnectionHandler = (*Handler)(nil) var _ common.Runnable = (*Handler)(nil) // Init the Handler instance with necessary parameters -func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error { +func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher, accessManager *connectiontracker.Manager) error { // Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler) if inbound := session.InboundFromContext(ctx); inbound != nil { t.tag = inbound.Tag @@ -58,6 +60,7 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin t.ctx = core.ToBackgroundDetachedContext(ctx) t.policyManager = pm t.dispatcher = dispatcher + t.accessManager = accessManager return nil } @@ -172,7 +175,15 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) { Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)}, Writer: buf.NewWriter(conn), } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && t.accessManager != nil { + ctx, link, accessRecord = t.accessManager.TrackAccessLink(ctx, accessMessage, link, cancel) + defer t.accessManager.FinishAccessRecord(accessRecord) + } if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil { + if accessRecord != nil { + t.accessManager.AbortAccessRecord(accessRecord, err) + } errors.LogError(ctx, errors.New("connection closed").Base(err)) } } @@ -198,8 +209,8 @@ func (t *Handler) Process(ctx context.Context, network net.Network, conn stat.Co func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { t := &Handler{config: config.(*Config)} - err := core.RequireFeatures(ctx, func(pm policy.Manager, dispatcher routing.Dispatcher) error { - return t.Init(ctx, pm, dispatcher) + err := core.RequireFeatures(ctx, func(pm policy.Manager, dispatcher routing.Dispatcher, trackerSvc connectiontracker.Feature) error { + return t.Init(ctx, pm, dispatcher, trackerSvc.Manager()) }) return t, err })) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 2d10a0060f48..1d22378b1e94 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -12,6 +12,7 @@ import ( "time" "unsafe" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/reverse" "github.com/xtls/xray-core/common" @@ -83,6 +84,8 @@ type Handler struct { observer features.Feature defaultDispatcher routing.Dispatcher ctx context.Context + accessManager *connectiontracker.Manager + connTracker *connectiontracker.Tracker fallbacks map[string]map[string]map[string]*Fallback // or nil // regexps map[string]*regexp.Regexp // or nil } @@ -90,6 +93,17 @@ type Handler struct { // New creates a new VLess inbound handler. func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) { v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + handler := &Handler{ inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), @@ -99,6 +113,8 @@ func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Val observer: v.GetFeature(extension.ObservatoryType()), defaultDispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher), ctx: ctx, + accessManager: trackerManager, + connTracker: trackerManager.NewTracker(), } if config.Decryption != "" && config.Decryption != "none" { @@ -243,6 +259,7 @@ func (h *Handler) AddUser(ctx context.Context, u *protocol.MemoryUser) error { // RemoveUser implements proxy.UserManager.RemoveUser(). func (h *Handler) RemoveUser(ctx context.Context, e string) error { + h.connTracker.CancelAll(strings.ToLower(e)) h.RemoveReverse(h.validator.GetByEmail(e)) return h.validator.Del(e) } @@ -529,6 +546,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s } errors.LogInfo(ctx, "received request for ", request.Destination()) + ctx, connCancel := context.WithCancel(ctx) + defer connCancel() + if email := strings.ToLower(request.User.Email); email != "" { + inboundTag := "" + if ib := session.InboundFromContext(ctx); ib != nil { + inboundTag = ib.Tag + } + connID, connEntry := h.connTracker.RegisterWithMeta(email, connCancel, inboundTag, "vless") + defer h.connTracker.Unregister(email, connID) + connection = connectiontracker.WrapConn(connection, connEntry) + } + inbound := session.InboundFromContext(ctx) if inbound == nil { panic("no inbound metadata") @@ -609,6 +638,13 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s ctx = session.ContextWithAllowedNetwork(ctx, net.Network_UDP) } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && h.accessManager != nil { + accessRecord = h.accessManager.NewAccessRecord(accessMessage, connCancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer h.accessManager.FinishAccessRecord(accessRecord) + } + trafficState := proxy.NewTrafficState(userSentID) clientReader := encoding.DecodeBodyAddons(reader, request, requestAddons) if requestAddons.Flow == vless.XRV { @@ -617,17 +653,35 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s bufferWriter := buf.NewBufferedWriter(buf.NewWriter(connection)) if err := encoding.EncodeResponseHeader(bufferWriter, request, responseAddons); err != nil { + if accessRecord != nil { + h.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to encode response header").Base(err).AtWarning() } clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx, connection, nil) bufferWriter.SetFlushNext() + if accessRecord != nil { + accessLink := connectiontracker.WrapAccessLink(&transport.Link{ + Reader: clientReader, + Writer: clientWriter, + }, accessRecord) + clientReader = accessLink.Reader + clientWriter = accessLink.Writer + } + if request.Command == protocol.RequestCommandRvs { r, err := h.GetReverse(account) if err != nil { return err } - return r.NewMux(ctx, dispatcher.WrapLink(ctx, h.policyManager, h.stats, &transport.Link{Reader: clientReader, Writer: clientWriter}), h.observer) + if err := r.NewMux(ctx, dispatcher.WrapLink(ctx, h.policyManager, h.stats, &transport.Link{Reader: clientReader, Writer: clientWriter}), h.observer); err != nil { + if accessRecord != nil { + h.accessManager.AbortAccessRecord(accessRecord, err) + } + return err + } + return nil } if err := dispatch.DispatchLink( @@ -636,6 +690,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s Writer: clientWriter, }, ); err != nil { + if accessRecord != nil { + h.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to dispatch request").Base(err) } return nil diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index e084a881e8ef..ae94c1699888 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -107,17 +108,32 @@ type Handler struct { clients *vmess.TimedUserValidator usersByEmail *userByEmail sessionHistory *encoding.SessionHistory + accessManager *connectiontracker.Manager + connTracker *connectiontracker.Tracker } // New creates a new VMess inbound handler. func New(ctx context.Context, config *Config) (*Handler, error) { v := core.MustFromContext(ctx) + var trackerManager *connectiontracker.Manager + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + trackerManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } + if trackerManager == nil { + return nil, errors.New("connection tracker feature is not available") + } + handler := &Handler{ policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager), clients: vmess.NewTimedUserValidator(), usersByEmail: newUserByEmail(config.GetDefaultValue()), sessionHistory: encoding.NewSessionHistory(), + accessManager: trackerManager, + connTracker: trackerManager.NewTracker(), } for _, user := range config.User { @@ -181,6 +197,7 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error { if !h.usersByEmail.Remove(email) { return errors.New("User ", email, " not found.") } + h.connTracker.CancelAll(strings.ToLower(email)) h.clients.Remove(email) return nil } @@ -277,12 +294,27 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + if email := strings.ToLower(request.User.Email); email != "" { + connID, connEntry := h.connTracker.RegisterWithMeta(email, cancel, inbound.Tag, "vmess") + defer h.connTracker.Unregister(email, connID) + connection = connectiontracker.WrapConn(connection, connEntry) + } ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && h.accessManager != nil { + accessRecord = h.accessManager.NewAccessRecord(accessMessage, cancel) + ctx = connectiontracker.ContextWithAccessRecord(ctx, accessRecord) + defer h.accessManager.FinishAccessRecord(accessRecord) + } link, err := dispatcher.Dispatch(ctx, request.Destination()) if err != nil { + if accessRecord != nil { + h.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("failed to dispatch request to ", request.Destination()).Base(err) } + link = connectiontracker.WrapAccessLink(link, accessRecord) requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) @@ -313,6 +345,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if err := task.Run(ctx, requestDonePost, responseDone); err != nil { common.Interrupt(link.Reader) common.Interrupt(link.Writer) + if accessRecord != nil { + h.accessManager.AbortAccessRecord(accessRecord, err) + } return errors.New("connection ends").Base(err) } diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 876d749f7d2b..fa6265676644 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -3,6 +3,7 @@ package wireguard import ( "context" + "github.com/xtls/xray-core/app/connectiontracker" "github.com/xtls/xray-core/common/buf" c "github.com/xtls/xray-core/common/ctx" "github.com/xtls/xray-core/common/errors" @@ -24,6 +25,7 @@ type Server struct { info routingInfo policyManager policy.Manager + accessManager *connectiontracker.Manager } type routingInfo struct { @@ -55,6 +57,12 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { }, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), } + if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error { + server.accessManager = trackerSvc.Manager() + return nil + }); err != nil { + return nil, err + } tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) if err != nil { @@ -158,11 +166,21 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { Reason: "", }) - err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{ + link := &transport.Link{ Reader: buf.NewReader(conn), Writer: buf.NewWriter(conn), - }) + } + var accessRecord *connectiontracker.AccessRecord + if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil && s.accessManager != nil { + ctx, link, accessRecord = s.accessManager.TrackAccessLink(ctx, accessMessage, link, cancel) + defer s.accessManager.FinishAccessRecord(accessRecord) + } + err := s.info.dispatcher.DispatchLink(ctx, dest, link) + if err != nil { + if accessRecord != nil { + s.accessManager.AbortAccessRecord(accessRecord, err) + } errors.LogInfoInner(ctx, err, "connection ends") } diff --git a/testing/scenarios/command_test.go b/testing/scenarios/command_test.go index 8f82bb2b447a..9c12f2b8768b 100644 --- a/testing/scenarios/command_test.go +++ b/testing/scenarios/command_test.go @@ -491,6 +491,195 @@ func TestCommanderAddRemoveUser(t *testing.T) { } } +// TestRemoveUserClosesExistingConnections verifies that calling RemoveUser via +// the gRPC API terminates connections that are currently active for that user. +func TestRemoveUserClosesExistingConnections(t *testing.T) { + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + common.Must(err) + defer tcpServer.Close() + + userID := protocol.NewID(uuid.New()) + cmdPort := tcp.PickPort() + serverPort := tcp.PickPort() + + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&commander.Config{ + Tag: "api", + Service: []*serial.TypedMessage{ + serial.ToTypedMessage(&command.Config{}), + }, + }), + serial.ToTypedMessage(&router.Config{ + Rule: []*router.RoutingRule{ + { + InboundTag: []string{"api"}, + TargetTag: &router.RoutingRule_Tag{ + Tag: "api", + }, + }, + }, + }), + serial.ToTypedMessage(&policy.Config{ + Level: map[uint32]*policy.Policy{ + 0: { + Timeout: &policy.Policy_Timeout{ + UplinkOnly: &policy.Second{Value: 0}, + DownlinkOnly: &policy.Second{Value: 0}, + }, + }, + }, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + Tag: "v", + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + User: []*protocol.User{ + { + Email: "test@example.com", + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + }), + }, + }, + }), + }, + { + Tag: "api", + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(cmdPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + RewriteAddress: net.NewIPOrDomain(dest.Address), + RewritePort: uint32(dest.Port), + AllowedNetworks: []net.Network{net.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{ + FinalRules: []*freedom.FinalRuleConfig{{Action: freedom.RuleAction_Allow}}, + }), + }, + }, + } + + clientPort := tcp.PickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&policy.Config{ + Level: map[uint32]*policy.Policy{ + 0: { + Timeout: &policy.Policy_Timeout{ + UplinkOnly: &policy.Second{Value: 0}, + DownlinkOnly: &policy.Second{Value: 0}, + }, + }, + }, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(clientPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + RewriteAddress: net.NewIPOrDomain(dest.Address), + RewritePort: uint32(dest.Port), + AllowedNetworks: []net.Network{net.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Receiver: &protocol.ServerEndpoint{ + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(serverPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + SecuritySettings: &protocol.SecurityConfig{ + Type: protocol.SecurityType_AES128_GCM, + }, + }), + }, + }, + }), + }, + }, + } + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + common.Must(err) + defer CloseAllServers(servers) + + // Open a raw TCP connection through the proxy chain (dokodemo → VMess → freedom → echo). + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: int(clientPort), + }) + common.Must(err) + defer conn.Close() + + // Exchange data to confirm the connection is fully established through VMess. + payload := make([]byte, 1024) + for i := range payload { + payload[i] = byte(i % 256) + } + if _, err := conn.Write(payload); err != nil { + t.Fatal("write failed:", err) + } + response := make([]byte, len(payload)) + conn.SetReadDeadline(time.Now().Add(time.Second * 5)) + if _, err := io.ReadFull(conn, response); err != nil { + t.Fatal("initial echo failed:", err) + } + conn.SetReadDeadline(time.Time{}) + + // Connect to the gRPC command port and remove the user. + cmdConn, err := grpc.Dial(fmt.Sprintf("127.0.0.1:%d", cmdPort), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) + common.Must(err) + defer cmdConn.Close() + + hsClient := command.NewHandlerServiceClient(cmdConn) + _, err = hsClient.AlterInbound(context.Background(), &command.AlterInboundRequest{ + Tag: "v", + Operation: serial.ToTypedMessage(&command.RemoveUserOperation{Email: "test@example.com"}), + }) + if err != nil { + t.Fatal("RemoveUser failed:", err) + } + + // The server must close the active connection after user removal. + // Give the cancellation chain (cancel → task.Run → conn.Close) time to propagate, + // then verify the connection is gone via a read with a generous deadline. + conn.SetReadDeadline(time.Now().Add(time.Second * 3)) + buf := make([]byte, 1) + _, readErr := conn.Read(buf) + if readErr == nil { + t.Fatal("expected connection to be closed after RemoveUser, but read returned data") + } + if nerr, ok := readErr.(interface{ Timeout() bool }); ok && nerr.Timeout() { + t.Fatalf("connection was NOT closed within 3 s of RemoveUser (read deadline hit): %v", readErr) + } + // Any non-timeout error (io.EOF, connection reset) confirms the server side closed. +} + func TestCommanderStats(t *testing.T) { tcpServer := tcp.Server{ MsgProcessor: xor, diff --git a/testing/scenarios/common.go b/testing/scenarios/common.go index a6eea215db84..3a85a3d3778c 100644 --- a/testing/scenarios/common.go +++ b/testing/scenarios/common.go @@ -98,6 +98,8 @@ var ( testBinaryPath string testBinaryCleanFn func() testBinaryPathGen sync.Once + testBinaryBuild sync.Once + testBinaryErr error ) func genTestBinaryPath() { @@ -121,6 +123,17 @@ func genTestBinaryPath() { }) } +func buildTestBinary(build func() error) error { + genTestBinaryPath() + testBinaryBuild.Do(func() { + if _, err := os.Stat(testBinaryPath); err == nil { + return + } + testBinaryErr = build() + }) + return testBinaryErr +} + func GetSourcePath() string { return filepath.Join("github.com", "xtls", "xray-core", "main") } diff --git a/testing/scenarios/common_coverage.go b/testing/scenarios/common_coverage.go index 160df47c0bca..411cf8bed3f3 100644 --- a/testing/scenarios/common_coverage.go +++ b/testing/scenarios/common_coverage.go @@ -12,13 +12,10 @@ import ( ) func BuildXray() error { - genTestBinaryPath() - if _, err := os.Stat(testBinaryPath); err == nil { - return nil - } - - cmd := exec.Command("go", "test", "-tags", "coverage coveragemain", "-coverpkg", "github.com/xtls/xray-core/...", "-c", "-o", testBinaryPath, GetSourcePath()) - return cmd.Run() + return buildTestBinary(func() error { + cmd := exec.Command("go", "test", "-tags", "coverage coveragemain", "-coverpkg", "github.com/xtls/xray-core/...", "-c", "-o", testBinaryPath, GetSourcePath()) + return cmd.Run() + }) } func RunXrayProtobuf(config []byte) *exec.Cmd { diff --git a/testing/scenarios/common_regular.go b/testing/scenarios/common_regular.go index 19efc7135454..394791634b47 100644 --- a/testing/scenarios/common_regular.go +++ b/testing/scenarios/common_regular.go @@ -11,16 +11,13 @@ import ( ) func BuildXray() error { - genTestBinaryPath() - if _, err := os.Stat(testBinaryPath); err == nil { - return nil - } - - fmt.Printf("Building Xray into path (%s)\n", testBinaryPath) - cmd := exec.Command("go", "build", "-o="+testBinaryPath, GetSourcePath()) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - return cmd.Run() + return buildTestBinary(func() error { + fmt.Printf("Building Xray into path (%s)\n", testBinaryPath) + cmd := exec.Command("go", "build", "-o="+testBinaryPath, GetSourcePath()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() + }) } func RunXrayProtobuf(config []byte) *exec.Cmd {