Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions app/connectiontracker/access.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package connectiontracker

import (
"context"
"sync/atomic"
"time"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
clog "github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/transport"
)

type accessRecordKey struct{}

// AccessRecord captures accepted-request access-log state until the request
// finishes and a final log line can be emitted.
type AccessRecord struct {
ID uint32

Msg *clog.AccessMessage

RequestBytes int64
ResponseBytes int64

LastActivity int64

cancel context.CancelFunc
finished atomic.Bool
}

// ContextWithAccessRecord stores r in ctx for deferred access-log handling.
func ContextWithAccessRecord(ctx context.Context, r *AccessRecord) context.Context {
return context.WithValue(ctx, accessRecordKey{}, r)
}

// AccessRecordFromContext returns the access record stored in ctx, if any.
func AccessRecordFromContext(ctx context.Context) *AccessRecord {
r, _ := ctx.Value(accessRecordKey{}).(*AccessRecord)
return r
}

func (r *AccessRecord) touch() {
atomic.StoreInt64(&r.LastActivity, time.Now().UnixNano())
}

func (r *AccessRecord) addRequestBytes(n int64) {
if n <= 0 {
return
}
atomic.AddInt64(&r.RequestBytes, n)
r.touch()
}

func (r *AccessRecord) addResponseBytes(n int64) {
if n <= 0 {
return
}
atomic.AddInt64(&r.ResponseBytes, n)
r.touch()
}

func cloneAccessMessage(msg *clog.AccessMessage) *clog.AccessMessage {
if msg == nil {
return nil
}
cloned := *msg
return &cloned
}

func (m *Manager) completeAccessRecord(r *AccessRecord, reason error) {
if r == nil || !r.finished.CompareAndSwap(false, true) {
return
}
msg := cloneAccessMessage(r.Msg)
if msg == nil {
return
}
if reason != nil {
msg.Reason = reason
}
msg.RequestBytes = atomic.LoadInt64(&r.RequestBytes)
msg.ResponseBytes = atomic.LoadInt64(&r.ResponseBytes)
clog.Record(msg)
}

// NewAccessRecord creates a new access record for an accepted request.
func (m *Manager) NewAccessRecord(msg *clog.AccessMessage, cancel context.CancelFunc) *AccessRecord {
if msg == nil {
return nil
}
record := &AccessRecord{
ID: atomic.AddUint32(&m.globalNext, 1),
Msg: msg,
cancel: cancel,
}
record.touch()
return record
}

// FinishAccessRecord emits the final access log for r, including payload
// totals accumulated during the request lifetime.
func (m *Manager) FinishAccessRecord(r *AccessRecord) {
m.completeAccessRecord(r, nil)
}

// AbortAccessRecord emits the final access log for r with an abort reason and
// cancels the tracked context if one was supplied.
func (m *Manager) AbortAccessRecord(r *AccessRecord, reason error) {
if r != nil && r.cancel != nil {
r.cancel()
}
m.completeAccessRecord(r, reason)
}

// TrackAccessLink stores a deferred access record in ctx and wraps link so
// payload bytes can be accounted for at the body reader/writer boundary.
func (m *Manager) TrackAccessLink(ctx context.Context, msg *clog.AccessMessage, link *transport.Link, cancel context.CancelFunc) (context.Context, *transport.Link, *AccessRecord) {
record := m.NewAccessRecord(msg, cancel)
if record == nil || link == nil {
return ctx, link, record
}
ctx = ContextWithAccessRecord(ctx, record)
link = WrapAccessLink(link, record)
return ctx, link, record
}

// WrapAccessLink wraps link so payload bytes are attributed to record.
func WrapAccessLink(link *transport.Link, record *AccessRecord) *transport.Link {
if link == nil || record == nil {
return link
}
if link.Reader != nil {
link.Reader = &TrackedAccessReader{Reader: link.Reader, record: record}
}
if link.Writer != nil {
link.Writer = &TrackedAccessWriter{Writer: link.Writer, record: record}
}
return link
}

// TrackedAccessReader counts request payload bytes read by an accepted request.
type TrackedAccessReader struct {
buf.Reader
record *AccessRecord
}

func (r *TrackedAccessReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
mb, err := r.Reader.ReadMultiBuffer()
if n := int64(mb.Len()); n > 0 && r.record != nil {
r.record.addRequestBytes(n)
}
return mb, err
}

func (r *TrackedAccessReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
if reader, ok := r.Reader.(buf.TimeoutReader); ok {
mb, err := reader.ReadMultiBufferTimeout(timeout)
if n := int64(mb.Len()); n > 0 && r.record != nil {
r.record.addRequestBytes(n)
}
return mb, err
}
return r.ReadMultiBuffer()
}

func (r *TrackedAccessReader) Interrupt() {
common.Interrupt(r.Reader)
}

func (r *TrackedAccessReader) Close() error {
return common.Close(r.Reader)
}

// TrackedAccessWriter counts response payload bytes written by an accepted
// request.
type TrackedAccessWriter struct {
buf.Writer
record *AccessRecord
}

func (w *TrackedAccessWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
n := int64(mb.Len())
err := w.Writer.WriteMultiBuffer(mb)
if err == nil && n > 0 && w.record != nil {
w.record.addResponseBytes(n)
}
return err
}

func (w *TrackedAccessWriter) Close() error {
return common.Close(w.Writer)
}
105 changes: 105 additions & 0 deletions app/connectiontracker/command/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package command

import (
"context"
"strings"

grpc "google.golang.org/grpc"

"github.com/xtls/xray-core/app/connectiontracker"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/core"
)

type connTrackerServer struct {
UnimplementedConnTrackerServiceServer
manager *connectiontracker.Manager
}

func (s *connTrackerServer) ListConnections(_ context.Context, _ *ListConnectionsRequest) (*ListConnectionsResponse, error) {
all := s.manager.ListAllConnections()
resp := &ListConnectionsResponse{
Connections: make([]*ConnInfo, 0, len(all)),
}
for _, c := range all {
resp.Connections = append(resp.Connections, toProto(c))
}
return resp, nil
}

func (s *connTrackerServer) CloseConnection(_ context.Context, req *CloseConnectionRequest) (*CloseConnectionResponse, error) {
found := s.manager.CloseGlobalConn(req.Id)
return &CloseConnectionResponse{Found: found}, nil
}

func (s *connTrackerServer) GetUserStats(_ context.Context, req *GetUserStatsRequest) (*GetUserStatsResponse, error) {
up, down, count := s.manager.GetUserStats(strings.ToLower(req.Email))
return &GetUserStatsResponse{
Uplink: up,
Downlink: down,
ConnCount: count,
}, nil
}

func (s *connTrackerServer) StreamConnections(_ *StreamConnectionsRequest, stream grpc.ServerStreamingServer[ConnectionUpdate]) error {
ch := s.manager.Subscribe()
defer s.manager.Unsubscribe(ch)

ctx := stream.Context()
for {
select {
case <-ctx.Done():
return ctx.Err()
case ev, ok := <-ch:
if !ok {
return nil
}
evType := ConnEventType_DISCONNECTED
if ev.Connected {
evType = ConnEventType_CONNECTED
}
if err := stream.Send(&ConnectionUpdate{
Event: evType,
Conn: toProto(ev.Info),
}); err != nil {
return err
}
}
}
}

func toProto(c connectiontracker.ConnectionInfo) *ConnInfo {
return &ConnInfo{
Id: c.ID,
Email: c.Email,
InboundTag: c.InboundTag,
Protocol: c.Protocol,
StartTime: c.StartTime.Unix(),
LastActivity: c.LastActivity.Unix(),
Uplink: c.Uplink,
Downlink: c.Downlink,
}
}

type service struct {
manager *connectiontracker.Manager
}

func (s *service) Register(server *grpc.Server) {
RegisterConnTrackerServiceServer(server, &connTrackerServer{manager: s.manager})
}

func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, _ interface{}) (interface{}, error) {
s := new(service)

if err := core.RequireFeatures(ctx, func(trackerSvc connectiontracker.Feature) error {
s.manager = trackerSvc.Manager()
return nil
}); err != nil {
return nil, err
}

return s, nil
}))
}
Loading