From f6e045b16116da523151ca219a72dd1cd26d69e3 Mon Sep 17 00:00:00 2001 From: ForestL18 <45709305+ForestL18@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:55:35 +0800 Subject: [PATCH] Make DnsHijack hijack specified protocols as expected --- listener/listener.go | 34 +++++++++++++++++----------------- listener/sing_tun/dns.go | 34 +++++++++++++++++++++------------- listener/sing_tun/server.go | 32 ++++++++++++++++++++++++-------- 3 files changed, 62 insertions(+), 38 deletions(-) diff --git a/listener/listener.go b/listener/listener.go index 2e25c8b8..32a7121b 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -37,11 +37,11 @@ var ( tproxyListener *tproxy.Listener tproxyUDPListener *tproxy.UDPListener mixedListener *mixed.Listener - mixedUDPLister *socks.UDPListener + mixedUDPListener *socks.UDPListener tunnelTCPListeners = map[string]*LT.Listener{} tunnelUDPListeners = map[string]*LT.PacketConn{} inboundListeners = map[string]C.InboundListener{} - tunLister *sing_tun.Listener + tunListener *sing_tun.Listener shadowSocksListener C.MultiAddrListener vmessListener *sing_vmess.Listener tuicListener *tuic.Listener @@ -74,10 +74,10 @@ type Ports struct { } func GetTunConf() LC.Tun { - if tunLister == nil { + if tunListener == nil { return LastTunConf } - return tunLister.Config() + return tunListener.Config() } func GetTuicConf() LC.TuicServer { @@ -463,10 +463,10 @@ func ReCreateMixed(port int, tunnel C.Tunnel) { shouldTCPIgnore = true } } - if mixedUDPLister != nil { - if mixedUDPLister.RawAddress() != addr { - mixedUDPLister.Close() - mixedUDPLister = nil + if mixedUDPListener != nil { + if mixedUDPListener.RawAddress() != addr { + mixedUDPListener.Close() + mixedUDPListener = nil } else { shouldUDPIgnore = true } @@ -485,7 +485,7 @@ func ReCreateMixed(port int, tunnel C.Tunnel) { return } - mixedUDPLister, err = socks.NewUDP(addr, tunnel) + mixedUDPListener, err = socks.NewUDP(addr, tunnel) if err != nil { mixedListener.Close() return @@ -512,8 +512,8 @@ func ReCreateTun(tunConf LC.Tun, tunnel C.Tunnel) { }() if tunConf.Equal(LastTunConf) { - if tunLister != nil { - tunLister.FlushDefaultInterface() + if tunListener != nil { + tunListener.FlushDefaultInterface() } return } @@ -524,13 +524,13 @@ func ReCreateTun(tunConf LC.Tun, tunnel C.Tunnel) { return } - lister, err := sing_tun.New(tunConf, tunnel) + listener, err := sing_tun.New(tunConf, tunnel) if err != nil { return } - tunLister = lister + tunListener = listener - log.Infoln("[TUN] Tun adapter listening at: %s", tunLister.Address()) + log.Infoln("[TUN] Tun adapter listening at: %s", tunListener.Address()) } func PatchTunnel(tunnels []LC.Tunnel, tunnel C.Tunnel) { @@ -716,9 +716,9 @@ func genAddr(host string, port int, allowLan bool) string { } func closeTunListener() { - if tunLister != nil { - tunLister.Close() - tunLister = nil + if tunListener != nil { + tunListener.Close() + tunListener = nil } } diff --git a/listener/sing_tun/dns.go b/listener/sing_tun/dns.go index 505f16ac..2f220676 100644 --- a/listener/sing_tun/dns.go +++ b/listener/sing_tun/dns.go @@ -20,23 +20,31 @@ import ( type ListenerHandler struct { *sing.ListenerHandler - DnsAdds []netip.AddrPort + DnsAdds map[string][]netip.AddrPort } -func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool { - if targetAddr.Addr().IsLoopback() && targetAddr.Port() == 53 { // cause by system stack - return true - } - for _, addrPort := range h.DnsAdds { - if addrPort == targetAddr || (addrPort.Addr().IsUnspecified() && targetAddr.Port() == 53) { - return true - } - } - return false +func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort, protocol string) bool { + if targetAddr.Addr().IsLoopback() && targetAddr.Port() == 53 { + return true + } + + for proto, addrPorts := range h.DnsAdds { + if proto != protocol && proto != "all" { + continue + } + + for _, addrPort := range addrPorts { + if addrPort == targetAddr || (addrPort.Addr().IsUnspecified() && targetAddr.Port() == 53) { + return true + } + } + } + + return false } func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - if h.ShouldHijackDns(metadata.Destination.AddrPort()) { + if h.ShouldHijackDns(metadata.Destination.AddrPort(), "tcp") { log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String()) return resolver.RelayDnsConn(ctx, conn, resolver.DefaultDnsReadTimeout) } @@ -44,7 +52,7 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta } func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error { - if h.ShouldHijackDns(metadata.Destination.AddrPort()) { + if h.ShouldHijackDns(metadata.Destination.AddrPort(), "udp") { log.Debugln("[DNS] hijack udp:%s from %s", metadata.Destination.String(), metadata.Source.String()) defer func() { _ = conn.Close() }() mutex := sync.Mutex{} diff --git a/listener/sing_tun/server.go b/listener/sing_tun/server.go index ba337b01..dad05032 100644 --- a/listener/sing_tun/server.go +++ b/listener/sing_tun/server.go @@ -209,31 +209,47 @@ func New(options LC.Tun, tunnel C.Tunnel, additions ...inbound.Addition) (l *Lis } } - var dnsAdds []netip.AddrPort + var dnsAdds map[string][]netip.AddrPort = make(map[string][]netip.AddrPort) + + supportedProtocols := map[string]bool{ + "all": true, + "tcp": true, + "udp": true, + } for _, d := range options.DNSHijack { - if _, after, ok := strings.Cut(d, "://"); ok { - d = after + protocol := "all" + address := d + + if parts := strings.SplitN(d, "://", 2); len(parts) == 2 { + protocol = parts[0] + address = parts[1] } - d = strings.Replace(d, "any", "0.0.0.0", 1) - addrPort, err := netip.ParseAddrPort(d) + + if !supportedProtocols[protocol] { + return nil, fmt.Errorf("unsupported dns-hijack protocol: %s", protocol) + } + + address = strings.Replace(address, "any", "0.0.0.0", 1) + + addrPort, err := netip.ParseAddrPort(address) if err != nil { return nil, fmt.Errorf("parse dns-hijack url error: %w", err) } - dnsAdds = append(dnsAdds, addrPort) + dnsAdds[protocol] = append(dnsAdds[protocol], addrPort) } var dnsServerIp []string for _, a := range options.Inet4Address { addrPort := netip.AddrPortFrom(a.Addr().Next(), 53) dnsServerIp = append(dnsServerIp, a.Addr().Next().String()) - dnsAdds = append(dnsAdds, addrPort) + dnsAdds["all"] = append(dnsAdds["all"], addrPort) } for _, a := range options.Inet6Address { addrPort := netip.AddrPortFrom(a.Addr().Next(), 53) dnsServerIp = append(dnsServerIp, a.Addr().Next().String()) - dnsAdds = append(dnsAdds, addrPort) + dnsAdds["all"] = append(dnsAdds["all"], addrPort) } h, err := sing.NewListenerHandler(sing.ListenerConfig{