diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 6144f5c75144..d51abf56d0ad 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -26,15 +26,15 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0) type Server struct { bindServer *netBindServer - info routingInfo - policyManager policy.Manager + info routingInfo + policyManager policy.Manager + tag string + sniffingRequest session.SniffingRequest } type routingInfo struct { - ctx context.Context - dispatcher routing.Dispatcher - inboundTag *session.Inbound - contentTag *session.Content + ctx context.Context + dispatcher routing.Dispatcher } func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { @@ -58,6 +58,14 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), } + // Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler) + if inbound := session.InboundFromContext(ctx); inbound != nil { + server.tag = inbound.Tag + } + if content := session.ContentFromContext(ctx); content != nil { + server.sniffingRequest = content.SniffingRequest + } + tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) if err != nil { return nil, err @@ -81,8 +89,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con s.info = routingInfo{ ctx: ctx, dispatcher: dispatcher, - inboundTag: session.InboundFromContext(ctx), - contentTag: session.ContentFromContext(ctx), } ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) @@ -129,21 +135,21 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) sid := session.NewID() ctx = c.ContextWithID(ctx, sid) - inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs) - if s.info.inboundTag != nil { - inbound = *s.info.inboundTag + + inbound := session.Inbound{ + Name: "wireguard", + Tag: s.tag, + CanSpliceCopy: 3, + // overwrite the source to use the tun address for each sub context. + // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context + // Currently we have no way to link to the original source address + Source: net.DestinationFromAddr(conn.RemoteAddr()), } - inbound.Name = "wireguard" - inbound.CanSpliceCopy = 3 - // overwrite the source to use the tun address for each sub context. - // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context - // Currently we have no way to link to the original source address - inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) ctx = session.ContextWithInbound(ctx, &inbound) - if s.info.contentTag != nil { - ctx = session.ContextWithContent(ctx, s.info.contentTag) - } + ctx = session.ContextWithContent(ctx, &session.Content{ + SniffingRequest: s.sniffingRequest, + }) ctx = session.SubContextFromMuxInbound(ctx) plcy := s.policyManager.ForLevel(0)