From 161ef62d78d440adde0307af2822e6c394a58802 Mon Sep 17 00:00:00 2001 From: giveup Date: Wed, 11 Sep 2024 18:18:56 +0800 Subject: [PATCH] feat: add ecs for each domain query --- config/config.go | 112 +++++++++++++++++++++++++++++++++++------------ dns/client.go | 9 ++++ dns/dhcp.go | 9 ++++ dns/doh.go | 9 ++++ dns/doq.go | 9 ++++ dns/rcode.go | 8 ++++ dns/resolver.go | 17 ++++--- dns/system.go | 9 ++++ dns/util.go | 19 +++++++- docs/config.yaml | 10 +++++ 10 files changed, 174 insertions(+), 37 deletions(-) diff --git a/config/config.go b/config/config.go index 0af30b39..420a1fe3 100644 --- a/config/config.go +++ b/config/config.go @@ -9,6 +9,7 @@ import ( "net/url" "path" "regexp" + "strconv" "strings" "time" @@ -29,6 +30,7 @@ import ( tlsC "github.com/metacubex/mihomo/component/tls" "github.com/metacubex/mihomo/component/trie" "github.com/metacubex/mihomo/component/updater" + "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant" providerTypes "github.com/metacubex/mihomo/constant/provider" snifferTypes "github.com/metacubex/mihomo/constant/sniffer" @@ -41,6 +43,7 @@ import ( RP "github.com/metacubex/mihomo/rules/provider" T "github.com/metacubex/mihomo/tunnel" + D "github.com/miekg/dns" orderedmap "github.com/wk8/go-ordered-map/v2" "golang.org/x/exp/slices" "gopkg.in/yaml.v3" @@ -190,25 +193,27 @@ type Config struct { } type RawDNS struct { - Enable bool `yaml:"enable" json:"enable"` - PreferH3 bool `yaml:"prefer-h3" json:"prefer-h3"` - IPv6 bool `yaml:"ipv6" json:"ipv6"` - IPv6Timeout uint `yaml:"ipv6-timeout" json:"ipv6-timeout"` - UseHosts bool `yaml:"use-hosts" json:"use-hosts"` - UseSystemHosts bool `yaml:"use-system-hosts" json:"use-system-hosts"` - RespectRules bool `yaml:"respect-rules" json:"respect-rules"` - NameServer []string `yaml:"nameserver" json:"nameserver"` - Fallback []string `yaml:"fallback" json:"fallback"` - FallbackFilter RawFallbackFilter `yaml:"fallback-filter" json:"fallback-filter"` - Listen string `yaml:"listen" json:"listen"` - EnhancedMode C.DNSMode `yaml:"enhanced-mode" json:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range" json:"fake-ip-range"` - FakeIPFilter []string `yaml:"fake-ip-filter" json:"fake-ip-filter"` - FakeIPFilterMode C.FilterMode `yaml:"fake-ip-filter-mode" json:"fake-ip-filter-mode"` - DefaultNameserver []string `yaml:"default-nameserver" json:"default-nameserver"` - CacheAlgorithm string `yaml:"cache-algorithm" json:"cache-algorithm"` - NameServerPolicy *orderedmap.OrderedMap[string, any] `yaml:"nameserver-policy" json:"nameserver-policy"` - ProxyServerNameserver []string `yaml:"proxy-server-nameserver" json:"proxy-server-nameserver"` + Enable bool `yaml:"enable" json:"enable"` + PreferH3 bool `yaml:"prefer-h3" json:"prefer-h3"` + IPv6 bool `yaml:"ipv6" json:"ipv6"` + IPv6Timeout uint `yaml:"ipv6-timeout" json:"ipv6-timeout"` + UseHosts bool `yaml:"use-hosts" json:"use-hosts"` + UseSystemHosts bool `yaml:"use-system-hosts" json:"use-system-hosts"` + RespectRules bool `yaml:"respect-rules" json:"respect-rules"` + NameServer []string `yaml:"nameserver" json:"nameserver"` + Fallback []string `yaml:"fallback" json:"fallback"` + FallbackFilter RawFallbackFilter `yaml:"fallback-filter" json:"fallback-filter"` + Listen string `yaml:"listen" json:"listen"` + EnhancedMode C.DNSMode `yaml:"enhanced-mode" json:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range" json:"fake-ip-range"` + FakeIPFilter []string `yaml:"fake-ip-filter" json:"fake-ip-filter"` + FakeIPFilterMode C.FilterMode `yaml:"fake-ip-filter-mode" json:"fake-ip-filter-mode"` + DefaultNameserver []string `yaml:"default-nameserver" json:"default-nameserver"` + CacheAlgorithm string `yaml:"cache-algorithm" json:"cache-algorithm"` + NameServerPolicy *orderedmap.OrderedMap[string, any] `yaml:"nameserver-policy" json:"nameserver-policy"` + DefaultECS string `yaml:"default-ecs" json:"default-ecs"` + DomainECSPolicy *orderedmap.OrderedMap[string, string] `yaml:"domain-ecs-policy" json:"domain-ecs-policy"` + ProxyServerNameserver []string `yaml:"proxy-server-nameserver" json:"proxy-server-nameserver"` } type RawFallbackFilter struct { @@ -1285,7 +1290,7 @@ func parsePureDNSServer(server string) string { } } -func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], ruleProviders map[string]providerTypes.RuleProvider, respectRules bool, preferH3 bool) ([]dns.Policy, error) { +func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], ruleProviders map[string]providerTypes.RuleProvider, respectRules bool, preferH3 bool, ecsPolicy *orderedmap.OrderedMap[string, string]) ([]dns.Policy, error) { var policy []dns.Policy re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`) @@ -1306,7 +1311,7 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro subkeys = strings.Split(subkeys[0], ",") for _, subkey := range subkeys { newKey := "geosite:" + subkey - policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers}) + policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers, Subnet: parseECSIfPresent(k, ecsPolicy, "")}) } } else if strings.Contains(strings.ToLower(k), "rule-set:") { subkeys := strings.Split(k, ":") @@ -1314,21 +1319,21 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro subkeys = strings.Split(subkeys[0], ",") for _, subkey := range subkeys { newKey := "rule-set:" + subkey - policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers}) + policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers, Subnet: parseECSIfPresent(k, ecsPolicy, "")}) } } else if re.MatchString(k) { subkeys := strings.Split(k, ",") for _, subkey := range subkeys { - policy = append(policy, dns.Policy{Domain: subkey, NameServers: nameservers}) + policy = append(policy, dns.Policy{Domain: subkey, NameServers: nameservers, Subnet: parseECSIfPresent(k, ecsPolicy, "")}) } } } else { if strings.Contains(strings.ToLower(k), "geosite:") { - policy = append(policy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers}) + policy = append(policy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers, Subnet: parseECSIfPresent(k, ecsPolicy, "")}) } else if strings.Contains(strings.ToLower(k), "rule-set:") { - policy = append(policy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers}) + policy = append(policy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers, Subnet: parseECSIfPresent(k, ecsPolicy, "")}) } else { - policy = append(policy, dns.Policy{Domain: k, NameServers: nameservers}) + policy = append(policy, dns.Policy{Domain: k, NameServers: nameservers, Subnet: parseECSIfPresent(k, ecsPolicy, "")}) } } } @@ -1342,14 +1347,14 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro if err != nil { return nil, err } - policy[idx] = dns.Policy{Matcher: matcher, NameServers: nameservers} + policy[idx] = dns.Policy{Matcher: matcher, NameServers: nameservers, Subnet: policy[idx].Subnet} } else if strings.HasPrefix(domain, "geosite:") { country := domain[8:] matcher, err := RC.NewGEOSITE(country, "dns.nameserver-policy") if err != nil { return nil, err } - policy[idx] = dns.Policy{Matcher: matcher, NameServers: nameservers} + policy[idx] = dns.Policy{Matcher: matcher, NameServers: nameservers, Subnet: policy[idx].Subnet} } else { if _, valid := trie.ValidAndSplitDomain(domain); !valid { return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain) @@ -1360,6 +1365,54 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro return policy, nil } +// parseECSIfPresent defaultECS will be override by domain-ecs-policy +func parseECSIfPresent(key string, ecsPolicy *orderedmap.OrderedMap[string, string], defaultECS string) (subnet *D.OPT) { + configECS, present := ecsPolicy.Get(key) + if !present { + if defaultECS == "" { + return nil + } + configECS = defaultECS + } + edns := strings.Split(configECS, "/") + if len(edns) < 1 || len(edns) > 2 { + log.Warnln("ecs is invalid ", configECS) + return nil + } + + edns0 := new(D.OPT) + edns0.Hdr.Name = "." + edns0.Hdr.Rrtype = D.TypeOPT + //According to RFC 6891, the max UDP size should be 4096 bytes + const maxUdpSize = 4096 + edns0.SetUDPSize(maxUdpSize) + edns0Subnet := new(D.EDNS0_SUBNET) + edns0Subnet.Code = D.EDNS0SUBNET + netMask, err := strconv.Atoi(edns[1]) + //ipv4 Max Netmask 32, ipv6 Max Netmask 128 + if err != nil || netMask < 0 || netMask > 128 { + log.Warnln("ecs netmask range is either invalid or cannot be converted", netMask) + return nil + } + edns0Subnet.SourceNetmask = uint8(netMask) + edns0Subnet.SourceScope = 0 + ip := net.ParseIP(edns[0]) + + ecsIP := ip.To4() + if ecsIP == nil { + // ipv6 address,Family should be 2 + edns0Subnet.Family = 2 + edns0Subnet.Address = ip.To16() + edns0.Option = append(edns0.Option, edns0Subnet) + return edns0 + } + // ipv4 address,Family should be 1 + edns0Subnet.Family = 1 + edns0Subnet.Address = ecsIP + edns0.Option = append(edns0.Option, edns0Subnet) + return edns0 +} + func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], ruleProviders map[string]providerTypes.RuleProvider) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { @@ -1388,9 +1441,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul return nil, err } - if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, ruleProviders, cfg.RespectRules, cfg.PreferH3); err != nil { + if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, ruleProviders, cfg.RespectRules, cfg.PreferH3, cfg.DomainECSPolicy); err != nil { return nil, err } + constant.DefaultECS = parseECSIfPresent("", cfg.DomainECSPolicy, cfg.DefaultECS) if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver, false, cfg.PreferH3); err != nil { return nil, err diff --git a/dns/client.go b/dns/client.go index 096b96a7..5cf10fef 100644 --- a/dns/client.go +++ b/dns/client.go @@ -19,6 +19,15 @@ type client struct { host string dialer *dnsDialer addr string + subnet *D.OPT +} + +func (c *client) SetSubnet(subnet *D.OPT) { + c.subnet = subnet +} + +func (c *client) GetSubnet() *D.OPT { + return c.subnet } var _ dnsClient = (*client)(nil) diff --git a/dns/dhcp.go b/dns/dhcp.go index dc1344f5..0f3fdc50 100644 --- a/dns/dhcp.go +++ b/dns/dhcp.go @@ -30,6 +30,15 @@ type dhcpClient struct { done chan struct{} clients []dnsClient err error + subnet *D.OPT +} + +func (d *dhcpClient) SetSubnet(subnet *D.OPT) { + d.subnet = subnet +} + +func (d *dhcpClient) GetSubnet() *D.OPT { + return d.subnet } var _ dnsClient = (*dhcpClient)(nil) diff --git a/dns/doh.go b/dns/doh.go index ffb65fce..ba744cb3 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -70,6 +70,15 @@ type dnsOverHTTPS struct { skipCertVerify bool ecsPrefix netip.Prefix ecsOverride bool + subnet *D.OPT +} + +func (doh *dnsOverHTTPS) SetSubnet(subnet *D.OPT) { + doh.subnet = subnet +} + +func (doh *dnsOverHTTPS) GetSubnet() *D.OPT { + return doh.subnet } // type check diff --git a/dns/doq.go b/dns/doq.go index ad936f95..d4c1dc41 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -62,6 +62,15 @@ type dnsOverQUIC struct { addr string dialer *dnsDialer + subnet *D.OPT +} + +func (doq *dnsOverQUIC) SetSubnet(subnet *D.OPT) { + doq.subnet = subnet +} + +func (doq *dnsOverQUIC) GetSubnet() *D.OPT { + return doq.subnet } // type check diff --git a/dns/rcode.go b/dns/rcode.go index 9777d2e7..f76d76e5 100644 --- a/dns/rcode.go +++ b/dns/rcode.go @@ -37,6 +37,14 @@ type rcodeClient struct { addr string } +func (r rcodeClient) SetSubnet(subnet *D.OPT) { + +} + +func (r rcodeClient) GetSubnet() *D.OPT { + return nil +} + var _ dnsClient = rcodeClient{} func (r rcodeClient) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) { diff --git a/dns/resolver.go b/dns/resolver.go index e03feef4..50a0437e 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -24,6 +24,8 @@ import ( type dnsClient interface { ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) Address() string + SetSubnet(subnet *D.OPT) + GetSubnet() *D.OPT } type dnsCache interface { @@ -407,6 +409,7 @@ type Policy struct { Domain string Matcher C.DomainMatcher NameServers []NameServer + Subnet *D.OPT } type Config struct { @@ -442,11 +445,12 @@ func NewResolver(config Config) *Resolver { NameServer dnsClient } - cacheTransform := func(nameserver []NameServer) (result []dnsClient) { + cacheTransform := func(nameserver []NameServer, subnet *D.OPT) (result []dnsClient) { LOOP: for _, ns := range nameserver { for _, nsc := range nameServerCache { if nsc.NameServer.Equal(ns) { + nsc.dnsClient.SetSubnet(subnet) result = append(result, nsc.dnsClient) continue LOOP } @@ -455,6 +459,7 @@ func NewResolver(config Config) *Resolver { dc := transform([]NameServer{ns}, defaultResolver) if len(dc) > 0 { dc := dc[0] + dc.SetSubnet(subnet) nameServerCache = append(nameServerCache, struct { NameServer dnsClient @@ -472,18 +477,18 @@ func NewResolver(config Config) *Resolver { } r := &Resolver{ ipv6: config.IPv6, - main: cacheTransform(config.Main), + main: cacheTransform(config.Main, nil), cache: cache, hosts: config.Hosts, ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, } if len(config.Fallback) != 0 { - r.fallback = cacheTransform(config.Fallback) + r.fallback = cacheTransform(config.Fallback, nil) } if len(config.ProxyServer) != 0 { - r.proxyServer = cacheTransform(config.ProxyServer) + r.proxyServer = cacheTransform(config.ProxyServer, nil) } if len(config.Policy) != 0 { @@ -503,12 +508,12 @@ func NewResolver(config Config) *Resolver { for _, policy := range config.Policy { if policy.Matcher != nil { - insertPolicy(domainMatcherPolicy{matcher: policy.Matcher, dnsClients: cacheTransform(policy.NameServers)}) + insertPolicy(domainMatcherPolicy{matcher: policy.Matcher, dnsClients: cacheTransform(policy.NameServers, policy.Subnet)}) } else { if triePolicy == nil { triePolicy = trie.New[[]dnsClient]() } - _ = triePolicy.Insert(policy.Domain, cacheTransform(policy.NameServers)) + _ = triePolicy.Insert(policy.Domain, cacheTransform(policy.NameServers, policy.Subnet)) } } insertPolicy(nil) diff --git a/dns/system.go b/dns/system.go index 944f2824..d7c4908f 100644 --- a/dns/system.go +++ b/dns/system.go @@ -24,6 +24,15 @@ type systemClient struct { mu sync.Mutex dnsClients map[string]*systemDnsClient lastFlush time.Time + subnet *D.OPT +} + +func (c *systemClient) SetSubnet(subnet *D.OPT) { + c.subnet = subnet +} + +func (c *systemClient) GetSubnet() *D.OPT { + return c.subnet } func (c *systemClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { diff --git a/dns/util.go b/dns/util.go index 92a86dbc..152f6267 100644 --- a/dns/util.go +++ b/dns/util.go @@ -14,6 +14,7 @@ import ( "github.com/metacubex/mihomo/common/picker" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/resolver" + "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" D "github.com/miekg/dns" @@ -194,12 +195,14 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M var noIpMsg *D.Msg for _, client := range clients { if _, isRCodeClient := client.(rcodeClient); isRCodeClient { + useSubnetIfPresent(client, m) msg, err = client.ExchangeContext(ctx, m) return msg, false, err } client := client // shadow define client to ensure the value captured by the closure will not be changed in the next loop fast.Go(func() (*D.Msg, error) { - log.Debugln("[DNS] resolve %s %s from %s", domain, qTypeStr, client.Address()) + useSubnetIfPresent(client, m) + log.Debugln("[DNS] resolve %s %s from %s with subnet %s", domain, qTypeStr, client.Address(), m.Extra) m, err := client.ExchangeContext(ctx, m) if err != nil { return nil, err @@ -209,7 +212,7 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M return nil, errors.New("server failure: " + D.RcodeToString[m.Rcode]) } ips := msgToIP(m) - log.Debugln("[DNS] %s --> %s %s from %s", domain, ips, qTypeStr, client.Address()) + log.Debugln("[DNS] %s --> %s %s from %s with subnet %s", domain, ips, qTypeStr, client.Address(), m.Extra) switch qType { case D.TypeAAAA: if len(ips) == 0 { @@ -238,3 +241,15 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M } return } + +func useSubnetIfPresent(client dnsClient, message *D.Msg) { + + if client.GetSubnet() == nil { + if message.Extra == nil { + message.Extra = append(message.Extra, constant.DefaultECS) + return + } + return + } + message.Extra = append(message.Extra, client.GetSubnet()) +} diff --git a/docs/config.yaml b/docs/config.yaml index 78e62d12..9f23df5c 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -316,6 +316,16 @@ dns: ## global,dns 为 rule-providers 中的名为 global 和 dns 规则订阅, ## 且 behavior 必须为 domain/classical,当为 classical 时仅会生效域名类规则 # "rule-set:global,dns": 8.8.8.8 +#默认的ecs + default-ecs: + "114.114.114.114/24" + domain-ecs-policy: + #配置指定域名要使用的ecs,必须和nameserver-policy保持一致,优先级大于默认ecs + #此处的域名需要和nameserver-policy配置的保持一致,即在nameserver-policy配置了多少个域名,此处也需要配置同样多个 + "geosite:tiktok": + "8.8.8.8.8/24" + "geosite:apple": + "223.5.5.5/24" proxies: # socks5 - name: "socks"