From b76737bdbbce7cb2e6b380e33b77e5cb37c70c40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AE=8B=E8=BE=B0=E6=96=87?= <me@songchenwen.com>
Date: Sun, 15 Sep 2019 13:36:45 +0800
Subject: [PATCH] Feature: add fallback filters (#105)

---
 README.md                |  4 +++
 config/config.go         | 67 +++++++++++++++++++++++++++++++---------
 dns/filters.go           | 26 ++++++++++++++++
 dns/resolver.go          | 58 +++++++++++++++++++++++-----------
 hub/executor/executor.go |  4 +++
 5 files changed, 127 insertions(+), 32 deletions(-)
 create mode 100644 dns/filters.go

diff --git a/README.md b/README.md
index 9f692a7a..1af142fa 100644
--- a/README.md
+++ b/README.md
@@ -138,6 +138,10 @@ experimental:
   #   - https://1.1.1.1/dns-query # dns over https
   # fallback: # concurrent request with nameserver, fallback used when GEOIP country isn't CN
   #   - tcp://1.1.1.1
+  # fallback-filter:
+  #   geoip: true # default
+  #   ipcidr: # ips in these subnets will be considered polluted
+  #     - 240.0.0.0/4
 
 Proxy:
 
diff --git a/config/config.go b/config/config.go
index 6d0942dd..16445111 100644
--- a/config/config.go
+++ b/config/config.go
@@ -40,13 +40,20 @@ type General struct {
 
 // DNS config
 type DNS struct {
-	Enable       bool             `yaml:"enable"`
-	IPv6         bool             `yaml:"ipv6"`
-	NameServer   []dns.NameServer `yaml:"nameserver"`
-	Fallback     []dns.NameServer `yaml:"fallback"`
-	Listen       string           `yaml:"listen"`
-	EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
-	FakeIPRange  *fakeip.Pool
+	Enable         bool             `yaml:"enable"`
+	IPv6           bool             `yaml:"ipv6"`
+	NameServer     []dns.NameServer `yaml:"nameserver"`
+	Fallback       []dns.NameServer `yaml:"fallback"`
+	FallbackFilter FallbackFilter   `yaml:"fallback-filter"`
+	Listen         string           `yaml:"listen"`
+	EnhancedMode   dns.EnhancedMode `yaml:"enhanced-mode"`
+	FakeIPRange    *fakeip.Pool
+}
+
+// FallbackFilter config
+type FallbackFilter struct {
+	GeoIP  bool         `yaml:"geoip"`
+	IPCIDR []*net.IPNet `yaml:"ipcidr"`
 }
 
 // Experimental config
@@ -66,13 +73,19 @@ type Config struct {
 }
 
 type rawDNS struct {
-	Enable       bool             `yaml:"enable"`
-	IPv6         bool             `yaml:"ipv6"`
-	NameServer   []string         `yaml:"nameserver"`
-	Fallback     []string         `yaml:"fallback"`
-	Listen       string           `yaml:"listen"`
-	EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
-	FakeIPRange  string           `yaml:"fake-ip-range"`
+	Enable         bool              `yaml:"enable"`
+	IPv6           bool              `yaml:"ipv6"`
+	NameServer     []string          `yaml:"nameserver"`
+	Fallback       []string          `yaml:"fallback"`
+	FallbackFilter rawFallbackFilter `yaml:"fallback-filter"`
+	Listen         string            `yaml:"listen"`
+	EnhancedMode   dns.EnhancedMode  `yaml:"enhanced-mode"`
+	FakeIPRange    string            `yaml:"fake-ip-range"`
+}
+
+type rawFallbackFilter struct {
+	GeoIP  bool     `yaml:"geoip"`
+	IPCIDR []string `yaml:"ipcidr"`
 }
 
 type rawConfig struct {
@@ -145,6 +158,10 @@ func readConfig(path string) (*rawConfig, error) {
 		DNS: rawDNS{
 			Enable:      false,
 			FakeIPRange: "198.18.0.1/16",
+			FallbackFilter: rawFallbackFilter{
+				GeoIP:  true,
+				IPCIDR: []string{},
+			},
 		},
 	}
 	err = yaml.Unmarshal([]byte(data), &rawConfig)
@@ -545,6 +562,20 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) {
 	return nameservers, nil
 }
 
+func parseFallbackIPCIDR(ips []string) ([]*net.IPNet, error) {
+	ipNets := []*net.IPNet{}
+
+	for idx, ip := range ips {
+		_, ipnet, err := net.ParseCIDR(ip)
+		if err != nil {
+			return nil, fmt.Errorf("DNS FallbackIP[%d] format error: %s", idx, err.Error())
+		}
+		ipNets = append(ipNets, ipnet)
+	}
+
+	return ipNets, nil
+}
+
 func parseDNS(cfg rawDNS) (*DNS, error) {
 	if cfg.Enable && len(cfg.NameServer) == 0 {
 		return nil, fmt.Errorf("If DNS configuration is turned on, NameServer cannot be empty")
@@ -555,6 +586,9 @@ func parseDNS(cfg rawDNS) (*DNS, error) {
 		Listen:       cfg.Listen,
 		IPv6:         cfg.IPv6,
 		EnhancedMode: cfg.EnhancedMode,
+		FallbackFilter: FallbackFilter{
+			IPCIDR: []*net.IPNet{},
+		},
 	}
 	var err error
 	if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer); err != nil {
@@ -578,6 +612,11 @@ func parseDNS(cfg rawDNS) (*DNS, error) {
 		dnsCfg.FakeIPRange = pool
 	}
 
+	dnsCfg.FallbackFilter.GeoIP = cfg.FallbackFilter.GeoIP
+	if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
+		dnsCfg.FallbackFilter.IPCIDR = fallbackip
+	}
+
 	return dnsCfg, nil
 }
 
diff --git a/dns/filters.go b/dns/filters.go
new file mode 100644
index 00000000..abb99f47
--- /dev/null
+++ b/dns/filters.go
@@ -0,0 +1,26 @@
+package dns
+
+import "net"
+
+type fallbackFilter interface {
+	Match(net.IP) bool
+}
+
+type geoipFilter struct{}
+
+func (gf *geoipFilter) Match(ip net.IP) bool {
+	if mmdb == nil {
+		return false
+	}
+
+	record, _ := mmdb.Country(ip)
+	return record.Country.IsoCode == "CN" || record.Country.IsoCode == ""
+}
+
+type ipnetFilter struct {
+	ipnet *net.IPNet
+}
+
+func (inf *ipnetFilter) Match(ip net.IP) bool {
+	return inf.ipnet.Contains(ip)
+}
diff --git a/dns/resolver.go b/dns/resolver.go
index 6ed65764..a9cd283e 100644
--- a/dns/resolver.go
+++ b/dns/resolver.go
@@ -46,14 +46,15 @@ type result struct {
 }
 
 type Resolver struct {
-	ipv6     bool
-	mapping  bool
-	fakeip   bool
-	pool     *fakeip.Pool
-	fallback []resolver
-	main     []resolver
-	group    singleflight.Group
-	cache    *cache.Cache
+	ipv6            bool
+	mapping         bool
+	fakeip          bool
+	pool            *fakeip.Pool
+	main            []resolver
+	fallback        []resolver
+	fallbackFilters []fallbackFilter
+	group           singleflight.Group
+	cache           *cache.Cache
 }
 
 // ResolveIP request with TypeA and TypeAAAA, priority return TypeAAAA
@@ -94,6 +95,15 @@ func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
 	return r.resolveIP(host, D.TypeAAAA)
 }
 
+func (r *Resolver) shouldFallback(ip net.IP) bool {
+	for _, filter := range r.fallbackFilters {
+		if filter.Match(ip) {
+			return true
+		}
+	}
+	return false
+}
+
 // Exchange a batch of dns request, and it use cache
 func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
 	if len(m.Question) == 0 {
@@ -195,13 +205,8 @@ func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) {
 	fallbackMsg := r.asyncExchange(r.fallback, m)
 	res := <-msgCh
 	if res.Error == nil {
-		if mmdb == nil {
-			return nil, errors.New("GeoIP cannot use")
-		}
-
 		if ips := r.msgToIP(res.Msg); len(ips) != 0 {
-			if record, _ := mmdb.Country(ips[0]); record.Country.IsoCode == "CN" || record.Country.IsoCode == "" {
-				// release channel
+			if r.shouldFallback(ips[0]) {
 				go func() { <-fallbackMsg }()
 				msg = res.Msg
 				return msg, err
@@ -272,18 +277,20 @@ type NameServer struct {
 	Addr string
 }
 
+type FallbackFilter struct {
+	GeoIP  bool
+	IPCIDR []*net.IPNet
+}
+
 type Config struct {
 	Main, Fallback []NameServer
 	IPv6           bool
 	EnhancedMode   EnhancedMode
+	FallbackFilter FallbackFilter
 	Pool           *fakeip.Pool
 }
 
 func New(config Config) *Resolver {
-	once.Do(func() {
-		mmdb, _ = geoip2.Open(C.Path.MMDB())
-	})
-
 	r := &Resolver{
 		ipv6:    config.IPv6,
 		main:    transform(config.Main),
@@ -292,8 +299,23 @@ func New(config Config) *Resolver {
 		fakeip:  config.EnhancedMode == FAKEIP,
 		pool:    config.Pool,
 	}
+
 	if len(config.Fallback) != 0 {
 		r.fallback = transform(config.Fallback)
 	}
+
+	fallbackFilters := []fallbackFilter{}
+	if config.FallbackFilter.GeoIP {
+		once.Do(func() {
+			mmdb, _ = geoip2.Open(C.Path.MMDB())
+		})
+
+		fallbackFilters = append(fallbackFilters, &geoipFilter{})
+	}
+	for _, ipnet := range config.FallbackFilter.IPCIDR {
+		fallbackFilters = append(fallbackFilters, &ipnetFilter{ipnet: ipnet})
+	}
+	r.fallbackFilters = fallbackFilters
+
 	return r
 }
diff --git a/hub/executor/executor.go b/hub/executor/executor.go
index 2883d5fc..321d6830 100644
--- a/hub/executor/executor.go
+++ b/hub/executor/executor.go
@@ -72,6 +72,10 @@ func updateDNS(c *config.DNS) {
 		IPv6:         c.IPv6,
 		EnhancedMode: c.EnhancedMode,
 		Pool:         c.FakeIPRange,
+		FallbackFilter: dns.FallbackFilter{
+			GeoIP:  c.FallbackFilter.GeoIP,
+			IPCIDR: c.FallbackFilter.IPCIDR,
+		},
 	})
 	dns.DefaultResolver = r
 	if err := dns.ReCreateServer(c.Listen, r); err != nil {