Make DnsHijack hijack specified protocols as expected

This commit is contained in:
ForestL18 2024-12-12 16:55:35 +08:00
parent 5d9d8f4d3b
commit f6e045b161
3 changed files with 62 additions and 38 deletions

View file

@ -37,11 +37,11 @@ var (
tproxyListener *tproxy.Listener
tproxyUDPListener *tproxy.UDPListener
mixedListener *mixed.Listener
mixedUDPLister *socks.UDPListener
mixedUDPListener *socks.UDPListener
tunnelTCPListeners = map[string]*LT.Listener{}
tunnelUDPListeners = map[string]*LT.PacketConn{}
inboundListeners = map[string]C.InboundListener{}
tunLister *sing_tun.Listener
tunListener *sing_tun.Listener
shadowSocksListener C.MultiAddrListener
vmessListener *sing_vmess.Listener
tuicListener *tuic.Listener
@ -74,10 +74,10 @@ type Ports struct {
}
func GetTunConf() LC.Tun {
if tunLister == nil {
if tunListener == nil {
return LastTunConf
}
return tunLister.Config()
return tunListener.Config()
}
func GetTuicConf() LC.TuicServer {
@ -463,10 +463,10 @@ func ReCreateMixed(port int, tunnel C.Tunnel) {
shouldTCPIgnore = true
}
}
if mixedUDPLister != nil {
if mixedUDPLister.RawAddress() != addr {
mixedUDPLister.Close()
mixedUDPLister = nil
if mixedUDPListener != nil {
if mixedUDPListener.RawAddress() != addr {
mixedUDPListener.Close()
mixedUDPListener = nil
} else {
shouldUDPIgnore = true
}
@ -485,7 +485,7 @@ func ReCreateMixed(port int, tunnel C.Tunnel) {
return
}
mixedUDPLister, err = socks.NewUDP(addr, tunnel)
mixedUDPListener, err = socks.NewUDP(addr, tunnel)
if err != nil {
mixedListener.Close()
return
@ -512,8 +512,8 @@ func ReCreateTun(tunConf LC.Tun, tunnel C.Tunnel) {
}()
if tunConf.Equal(LastTunConf) {
if tunLister != nil {
tunLister.FlushDefaultInterface()
if tunListener != nil {
tunListener.FlushDefaultInterface()
}
return
}
@ -524,13 +524,13 @@ func ReCreateTun(tunConf LC.Tun, tunnel C.Tunnel) {
return
}
lister, err := sing_tun.New(tunConf, tunnel)
listener, err := sing_tun.New(tunConf, tunnel)
if err != nil {
return
}
tunLister = lister
tunListener = listener
log.Infoln("[TUN] Tun adapter listening at: %s", tunLister.Address())
log.Infoln("[TUN] Tun adapter listening at: %s", tunListener.Address())
}
func PatchTunnel(tunnels []LC.Tunnel, tunnel C.Tunnel) {
@ -716,9 +716,9 @@ func genAddr(host string, port int, allowLan bool) string {
}
func closeTunListener() {
if tunLister != nil {
tunLister.Close()
tunLister = nil
if tunListener != nil {
tunListener.Close()
tunListener = nil
}
}

View file

@ -20,23 +20,31 @@ import (
type ListenerHandler struct {
*sing.ListenerHandler
DnsAdds []netip.AddrPort
DnsAdds map[string][]netip.AddrPort
}
func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool {
if targetAddr.Addr().IsLoopback() && targetAddr.Port() == 53 { // cause by system stack
return true
}
for _, addrPort := range h.DnsAdds {
if addrPort == targetAddr || (addrPort.Addr().IsUnspecified() && targetAddr.Port() == 53) {
return true
}
}
return false
func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort, protocol string) bool {
if targetAddr.Addr().IsLoopback() && targetAddr.Port() == 53 {
return true
}
for proto, addrPorts := range h.DnsAdds {
if proto != protocol && proto != "all" {
continue
}
for _, addrPort := range addrPorts {
if addrPort == targetAddr || (addrPort.Addr().IsUnspecified() && targetAddr.Port() == 53) {
return true
}
}
}
return false
}
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
if h.ShouldHijackDns(metadata.Destination.AddrPort(), "tcp") {
log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String())
return resolver.RelayDnsConn(ctx, conn, resolver.DefaultDnsReadTimeout)
}
@ -44,7 +52,7 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta
}
func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error {
if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
if h.ShouldHijackDns(metadata.Destination.AddrPort(), "udp") {
log.Debugln("[DNS] hijack udp:%s from %s", metadata.Destination.String(), metadata.Source.String())
defer func() { _ = conn.Close() }()
mutex := sync.Mutex{}

View file

@ -209,31 +209,47 @@ func New(options LC.Tun, tunnel C.Tunnel, additions ...inbound.Addition) (l *Lis
}
}
var dnsAdds []netip.AddrPort
var dnsAdds map[string][]netip.AddrPort = make(map[string][]netip.AddrPort)
supportedProtocols := map[string]bool{
"all": true,
"tcp": true,
"udp": true,
}
for _, d := range options.DNSHijack {
if _, after, ok := strings.Cut(d, "://"); ok {
d = after
protocol := "all"
address := d
if parts := strings.SplitN(d, "://", 2); len(parts) == 2 {
protocol = parts[0]
address = parts[1]
}
d = strings.Replace(d, "any", "0.0.0.0", 1)
addrPort, err := netip.ParseAddrPort(d)
if !supportedProtocols[protocol] {
return nil, fmt.Errorf("unsupported dns-hijack protocol: %s", protocol)
}
address = strings.Replace(address, "any", "0.0.0.0", 1)
addrPort, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("parse dns-hijack url error: %w", err)
}
dnsAdds = append(dnsAdds, addrPort)
dnsAdds[protocol] = append(dnsAdds[protocol], addrPort)
}
var dnsServerIp []string
for _, a := range options.Inet4Address {
addrPort := netip.AddrPortFrom(a.Addr().Next(), 53)
dnsServerIp = append(dnsServerIp, a.Addr().Next().String())
dnsAdds = append(dnsAdds, addrPort)
dnsAdds["all"] = append(dnsAdds["all"], addrPort)
}
for _, a := range options.Inet6Address {
addrPort := netip.AddrPortFrom(a.Addr().Next(), 53)
dnsServerIp = append(dnsServerIp, a.Addr().Next().String())
dnsAdds = append(dnsAdds, addrPort)
dnsAdds["all"] = append(dnsAdds["all"], addrPort)
}
h, err := sing.NewListenerHandler(sing.ListenerConfig{