From 4bd3ae52bd88f12079e2ce36b4d2fa12fa679339 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 10 Mar 2025 10:45:07 +0800 Subject: [PATCH] chore: dialer will consider the routing of the local interface when auto-detect-interface in tun is enabled for #1881 #1819 --- component/dialer/dialer.go | 21 +++++-- component/dialer/options.go | 7 +++ listener/listener.go | 3 - listener/sing_tun/server.go | 55 ++++++++++++------- transport/hysteria/conns/faketcp/tcp_linux.go | 10 +++- 5 files changed, 65 insertions(+), 31 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 4fd051ef..da217f28 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -88,6 +88,15 @@ func ListenPacket(ctx context.Context, network, address string, rAddrPort netip. if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CMFA) socketHookToListenConfig(lc) } else { + if cfg.interfaceName == "" { + if finder := DefaultInterfaceFinder.Load(); finder != nil { + cfg.interfaceName = finder.FindInterfaceName(rAddrPort.Addr()) + } + } + if rAddrPort.Addr().Unmap().IsLoopback() { + // avoid "The requested address is not valid in its context." + cfg.interfaceName = "" + } if cfg.interfaceName != "" { bind := bindIfaceToListenConfig if cfg.fallbackBind { @@ -153,6 +162,11 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA) socketHookToToDialer(dialer) } else { + if opt.interfaceName == "" { + if finder := DefaultInterfaceFinder.Load(); finder != nil { + opt.interfaceName = finder.FindInterfaceName(destination) + } + } if opt.interfaceName != "" { bind := bindIfaceToDialer if opt.fallbackBind { @@ -373,12 +387,7 @@ func (d Dialer) DialContext(ctx context.Context, network, address string) (net.C } func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { - opt := d.Opt // make a copy - if rAddrPort.Addr().Unmap().IsLoopback() { - // avoid "The requested address is not valid in its context." - WithInterface("")(&opt) - } - return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, WithOption(opt)) + return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, WithOption(d.Opt)) } func NewDialer(options ...Option) Dialer { diff --git a/component/dialer/options.go b/component/dialer/options.go index c0c21891..3da55ae6 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -3,6 +3,7 @@ package dialer import ( "context" "net" + "net/netip" "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/component/resolver" @@ -12,8 +13,14 @@ var ( DefaultOptions []Option DefaultInterface = atomic.NewTypedValue[string]("") DefaultRoutingMark = atomic.NewInt32(0) + + DefaultInterfaceFinder = atomic.NewTypedValue[InterfaceFinder](nil) ) +type InterfaceFinder interface { + FindInterfaceName(destination netip.Addr) string +} + type NetDialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } diff --git a/listener/listener.go b/listener/listener.go index 2e25c8b8..cb8842dd 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -512,9 +512,6 @@ func ReCreateTun(tunConf LC.Tun, tunnel C.Tunnel) { }() if tunConf.Equal(LastTunConf) { - if tunLister != nil { - tunLister.FlushDefaultInterface() - } return } diff --git a/listener/sing_tun/server.go b/listener/sing_tun/server.go index ba337b01..56527ded 100644 --- a/listener/sing_tun/server.go +++ b/listener/sing_tun/server.go @@ -52,6 +52,8 @@ type Listener struct { autoRedirect tun.AutoRedirect autoRedirectOutputMark int32 + cDialerInterfaceFinder dialer.InterfaceFinder + ruleUpdateCallbackCloser io.Closer ruleUpdateMutex sync.Mutex routeAddressMap map[string]*netipx.IPSet @@ -290,13 +292,25 @@ func New(options LC.Tun, tunnel C.Tunnel, additions ...inbound.Addition) (l *Lis } l.defaultInterfaceMonitor = defaultInterfaceMonitor defaultInterfaceMonitor.RegisterCallback(func(event int) { - l.FlushDefaultInterface() + iface.FlushCache() + resolver.ResetConnection() // reset resolver's connection after default interface changed }) err = defaultInterfaceMonitor.Start() if err != nil { err = E.Cause(err, "start DefaultInterfaceMonitor") return } + + if options.AutoDetectInterface { + l.cDialerInterfaceFinder = &cDialerInterfaceFinder{ + tunName: tunName, + defaultInterfaceMonitor: defaultInterfaceMonitor, + } + if !dialer.DefaultInterfaceFinder.CompareAndSwap(nil, l.cDialerInterfaceFinder) { + err = E.New("don't allowed two tun listener using auto-detect-interface") + return + } + } } tunOptions := tun.Options{ @@ -503,27 +517,25 @@ func (l *Listener) updateRule(ruleProvider provider.RuleProvider, exclude bool, } } -func (l *Listener) FlushDefaultInterface() { - if l.options.AutoDetectInterface && l.defaultInterfaceMonitor != nil { - for _, destination := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified(), netip.MustParseAddr("1.1.1.1")} { - autoDetectInterfaceName := l.defaultInterfaceMonitor.DefaultInterfaceName(destination) - if autoDetectInterfaceName == l.tunName { - log.Warnln("[TUN] Auto detect interface by %s get same name with tun", destination.String()) - } else if autoDetectInterfaceName == "" || autoDetectInterfaceName == "" { - log.Warnln("[TUN] Auto detect interface by %s get empty name.", destination.String()) - } else { - if old := dialer.DefaultInterface.Swap(autoDetectInterfaceName); old != autoDetectInterfaceName { - log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, autoDetectInterfaceName) - iface.FlushCache() - resolver.ResetConnection() // reset resolver's connection after default interface changed - } - return - } - } - if dialer.DefaultInterface.CompareAndSwap("", "") { - log.Warnln("[TUN] Auto detect interface failed, set '' to DefaultInterface to avoid lookback") +type cDialerInterfaceFinder struct { + tunName string + defaultInterfaceMonitor tun.DefaultInterfaceMonitor +} + +func (d *cDialerInterfaceFinder) FindInterfaceName(destination netip.Addr) string { + for _, dest := range []netip.Addr{destination, netip.IPv4Unspecified(), netip.IPv6Unspecified()} { + autoDetectInterfaceName := d.defaultInterfaceMonitor.DefaultInterfaceName(dest) + if autoDetectInterfaceName == d.tunName { + log.Warnln("[TUN] Auto detect interface for %s get same name with tun", destination.String()) + } else if autoDetectInterfaceName == "" || autoDetectInterfaceName == "" { + log.Warnln("[TUN] Auto detect interface for %s get empty name.", destination.String()) + } else { + log.Debugln("[TUN] Auto detect interface for %s --> %s", destination, autoDetectInterfaceName) + return autoDetectInterfaceName } } + log.Warnln("[TUN] Auto detect interface for %s failed, return '' to avoid lookback", destination) + return "" } func uidToRange(uidList []uint32) []ranges.Range[uint32] { @@ -564,6 +576,9 @@ func (l *Listener) Close() error { if l.autoRedirectOutputMark != 0 { dialer.DefaultRoutingMark.CompareAndSwap(l.autoRedirectOutputMark, 0) } + if l.cDialerInterfaceFinder != nil { + dialer.DefaultInterfaceFinder.CompareAndSwap(l.cDialerInterfaceFinder, nil) + } return common.Close( l.ruleUpdateCallbackCloser, l.tunStack, diff --git a/transport/hysteria/conns/faketcp/tcp_linux.go b/transport/hysteria/conns/faketcp/tcp_linux.go index 2aaaf139..fb59cf98 100644 --- a/transport/hysteria/conns/faketcp/tcp_linux.go +++ b/transport/hysteria/conns/faketcp/tcp_linux.go @@ -404,8 +404,14 @@ func Dial(network, address string) (*TCPConn, error) { var lTcpAddr *net.TCPAddr var lIpAddr *net.IPAddr - if ifaceName := dialer.DefaultInterface.Load(); len(ifaceName) > 0 { - rAddrPort := raddr.AddrPort() + rAddrPort := raddr.AddrPort() + ifaceName := dialer.DefaultInterface.Load() + if ifaceName == "" { + if finder := dialer.DefaultInterfaceFinder.Load(); finder != nil { + ifaceName = finder.FindInterfaceName(rAddrPort.Addr()) + } + } + if len(ifaceName) > 0 { addr, err := dialer.LookupLocalAddrFromIfaceName(ifaceName, network, rAddrPort.Addr(), int(rAddrPort.Port())) if err != nil { return nil, err