From c830b8aaf7a8889e51a30972ad7b3a0e13e23a34 Mon Sep 17 00:00:00 2001
From: wwqgtxx <wwqgtxx@gmail.com>
Date: Sun, 28 Jul 2024 10:07:37 +0800
Subject: [PATCH] feat: support convert `mrs` format back to `text` format

---
 component/cidr/ipcidr_set.go      | 10 ++++++++
 component/trie/domain.go          | 19 +++++++++++-----
 component/trie/domain_set.go      | 38 ++++++++++++++++++++++++++++++-
 component/trie/domain_set_test.go | 20 ++++++++++++++++
 component/trie/domain_test.go     |  3 ++-
 docs/config.yaml                  | 15 ++++++++----
 rules/provider/domain_strategy.go | 22 ++++++++++++++++++
 rules/provider/ipcidr_strategy.go |  9 ++++++++
 rules/provider/mrs_converter.go   | 12 ++++++++++
 rules/provider/provider.go        |  1 +
 10 files changed, 137 insertions(+), 12 deletions(-)

diff --git a/component/cidr/ipcidr_set.go b/component/cidr/ipcidr_set.go
index 521fabab..49071460 100644
--- a/component/cidr/ipcidr_set.go
+++ b/component/cidr/ipcidr_set.go
@@ -57,6 +57,16 @@ func (set *IpCidrSet) Merge() error {
 	return nil
 }
 
+func (set *IpCidrSet) Foreach(f func(prefix netip.Prefix) bool) {
+	for _, ipRange := range set.rr {
+		for _, p := range ipRange.Prefixes() {
+			if keepGoing := f(p); !keepGoing {
+				return // f asked to stop; skip every remaining prefix and range
+			}
+		}
+	}
+}
+
 // ToIPSet not safe convert to *netipx.IPSet
 // be careful, must be used after Merge
 func (set *IpCidrSet) ToIPSet() *netipx.IPSet {
diff --git a/component/trie/domain.go b/component/trie/domain.go
index 3decbb02..db30402e 100644
--- a/component/trie/domain.go
+++ b/component/trie/domain.go
@@ -123,16 +123,18 @@ func (t *DomainTrie[T]) Optimize() {
 	t.root.optimize()
 }
 
-func (t *DomainTrie[T]) Foreach(print func(domain string, data T)) {
+// Foreach walks the trie and calls fn with each stored domain; it stops once fn returns false.
+func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) {
 	for key, data := range t.root.getChildren() {
-		recursion([]string{key}, data, print)
-		if data != nil && data.inited {
-			print(joinDomain([]string{key}), data.data)
+		// Descendants first, then the label itself; a false from either ends the walk.
+		if !recursion([]string{key}, data, fn) ||
+			(data != nil && data.inited && !fn(joinDomain([]string{key}), data.data)) {
+			return
 		}
 	}
 }
 
-func recursion[T any](items []string, node *Node[T], fn func(domain string, data T)) {
+func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool {
 	for key, data := range node.getChildren() {
 		newItems := append([]string{key}, items...)
 		if data != nil && data.inited {
@@ -140,10 +142,15 @@ func recursion[T any](items []string, node *Node[T], fn func(domain string, data
 			if domain[0] == domainStepByte {
 				domain = complexWildcard + domain
 			}
-			fn(domain, data.Data())
+			if !fn(domain, data.Data()) {
+				return false
+			}
+		}
+		if !recursion(newItems, data, fn) {
+			return false
 		}
-		recursion(newItems, data, fn)
 	}
+	return true
 }
 
 func joinDomain(items []string) string {
diff --git a/component/trie/domain_set.go b/component/trie/domain_set.go
index 860d1235..7778d133 100644
--- a/component/trie/domain_set.go
+++ b/component/trie/domain_set.go
@@ -28,8 +28,9 @@ type qElt struct{ s, e, col int }
 // NewDomainSet creates a new *DomainSet struct, from a DomainTrie.
 func (t *DomainTrie[T]) NewDomainSet() *DomainSet {
 	reserveDomains := make([]string, 0)
-	t.Foreach(func(domain string, data T) {
+	t.Foreach(func(domain string, data T) bool {
 		reserveDomains = append(reserveDomains, utils.Reverse(domain))
+		return true
 	})
 	// ensure that the same prefix is continuous
 	// and according to the ascending sequence of length
@@ -136,6 +137,41 @@ func (ss *DomainSet) Has(key string) bool {
 
 }
 
+// keys walks the set depth-first from node 0 and calls f with each stored key until f returns false.
+func (ss *DomainSet) keys(f func(key string) bool) {
+	var currentKey []byte
+	var traverse func(int, int) bool
+	traverse = func(nodeId, bmIdx int) bool {
+		if getBit(ss.leaves, nodeId) != 0 { // this node is marked as the end of a stored key
+			if !f(string(currentKey)) {
+				return false
+			}
+		}
+
+		for ; ; bmIdx++ {
+			if getBit(ss.labelBitmap, bmIdx) != 0 { // set bit: no more labels for this node
+				return true
+			}
+			nextLabel := ss.labels[bmIdx-nodeId]
+			currentKey = append(currentKey, nextLabel)
+			nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
+			nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
+			if !traverse(nextNodeId, nextBmIdx) {
+				return false
+			}
+			currentKey = currentKey[:len(currentKey)-1] // backtrack before trying the next label
+		}
+	}
+
+	traverse(0, 0)
+}
+
+// Foreach calls f with each stored key, passed through utils.Reverse, until f returns false.
+func (ss *DomainSet) Foreach(f func(key string) bool) {
+	visit := func(stored string) bool { return f(utils.Reverse(stored)) }
+	ss.keys(visit)
+}
+
 func setBit(bm *[]uint64, i int, v int) {
 	for i>>6 >= len(*bm) {
 		*bm = append(*bm, 0)
diff --git a/component/trie/domain_set_test.go b/component/trie/domain_set_test.go
index 77106d5f..e343d11d 100644
--- a/component/trie/domain_set_test.go
+++ b/component/trie/domain_set_test.go
@@ -1,12 +1,29 @@
 package trie_test
 
 import (
+	"golang.org/x/exp/slices"
 	"testing"
 
 	"github.com/metacubex/mihomo/component/trie"
 	"github.com/stretchr/testify/assert"
 )
 
+// testDump asserts that set.Foreach and tree.Foreach yield the same domains, ignoring order.
+func testDump(t *testing.T, tree *trie.DomainTrie[struct{}], set *trie.DomainSet) {
+	var fromTree, fromSet []string
+	tree.Foreach(func(domain string, _ struct{}) bool {
+		fromTree = append(fromTree, domain)
+		return true
+	})
+	set.Foreach(func(key string) bool {
+		fromSet = append(fromSet, key)
+		return true
+	})
+	slices.Sort(fromTree)
+	slices.Sort(fromSet)
+	assert.Equal(t, fromTree, fromSet)
+}
+
 func TestDomainSet(t *testing.T) {
 	tree := trie.New[struct{}]()
 	domainSet := []string{
@@ -33,6 +50,7 @@ func TestDomainSet(t *testing.T) {
 	assert.True(t, set.Has("google.com"))
 	assert.False(t, set.Has("qq.com"))
 	assert.False(t, set.Has("www.baidu.com"))
+	testDump(t, tree, set)
 }
 
 func TestDomainSetComplexWildcard(t *testing.T) {
@@ -55,6 +73,7 @@ func TestDomainSetComplexWildcard(t *testing.T) {
 	assert.False(t, set.Has("google.com"))
 	assert.True(t, set.Has("www.baidu.com"))
 	assert.True(t, set.Has("test.test.baidu.com"))
+	testDump(t, tree, set)
 }
 
 func TestDomainSetWildcard(t *testing.T) {
@@ -82,4 +101,5 @@ func TestDomainSetWildcard(t *testing.T) {
 	assert.False(t, set.Has("a.www.google.com"))
 	assert.False(t, set.Has("test.qq.com"))
 	assert.False(t, set.Has("test.test.test.qq.com"))
+	testDump(t, tree, set)
 }
diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go
index 4c5d8002..916f6107 100644
--- a/component/trie/domain_test.go
+++ b/component/trie/domain_test.go
@@ -121,8 +121,9 @@ func TestTrie_Foreach(t *testing.T) {
 		assert.NoError(t, tree.Insert(domain, localIP))
 	}
 	count := 0
-	tree.Foreach(func(domain string, data netip.Addr) {
+	tree.Foreach(func(domain string, data netip.Addr) bool {
 		count++
+		return true
 	})
 	assert.Equal(t, 7, count)
 }
diff --git a/docs/config.yaml b/docs/config.yaml
index 669c8be7..d7c686d0 100644
--- a/docs/config.yaml
+++ b/docs/config.yaml
@@ -944,10 +944,17 @@ rule-providers:
     type: file
   rule3:
     # mrs类型ruleset,目前仅支持domain和ipcidr(即不支持classical),
-    # behavior=domain,format=yaml 可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换得到
-    # behavior=domain,format=text 可以通过“mihomo convert-ruleset domain text XXX.text XXX.mrs”转换得到
-    # behavior=ipcidr,format=yaml 可以通过“mihomo convert-ruleset ipcidr yaml XXX.yaml XXX.mrs”转换得到
-    # behavior=ipcidr,format=text 可以通过“mihomo convert-ruleset ipcidr text XXX.text XXX.mrs”转换得到
+    #
+    # 对于behavior=domain:
+    #  - format=yaml 可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换到mrs格式
+    #  - format=text 可以通过“mihomo convert-ruleset domain text XXX.text XXX.mrs”转换到mrs格式
+    #  - XXX.mrs 可以通过“mihomo convert-ruleset domain mrs XXX.mrs XXX.text”转换回text格式(暂不支持转换回yaml格式)
+    #
+    # 对于behavior=ipcidr:
+    #  - format=yaml 可以通过“mihomo convert-ruleset ipcidr yaml XXX.yaml XXX.mrs”转换到mrs格式
+    #  - format=text 可以通过“mihomo convert-ruleset ipcidr text XXX.text XXX.mrs”转换到mrs格式
+    #  - XXX.mrs 可以通过“mihomo convert-ruleset ipcidr mrs XXX.mrs XXX.text”转换回text格式(暂不支持转换回yaml格式)
+    #
     type: http
     url: "url"
     format: mrs
diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go
index a999f5bd..b893f038 100644
--- a/rules/provider/domain_strategy.go
+++ b/rules/provider/domain_strategy.go
@@ -9,6 +9,8 @@ import (
 	C "github.com/metacubex/mihomo/constant"
 	P "github.com/metacubex/mihomo/constant/provider"
 	"github.com/metacubex/mihomo/log"
+
+	"golang.org/x/exp/slices"
 )
 
 type domainStrategy struct {
@@ -78,6 +80,26 @@ func (d *domainStrategy) WriteMrs(w io.Writer) error {
 	return d.domainSet.WriteBin(w)
 }
 
+// DumpMrs calls f with each key of the domain set in sorted order until f returns false.
+func (d *domainStrategy) DumpMrs(f func(key string) bool) {
+	if d.domainSet == nil {
+		return
+	}
+	var keys []string
+	d.domainSet.Foreach(func(key string) bool {
+		keys = append(keys, key)
+		return true
+	})
+	slices.Sort(keys)
+	for _, key := range keys {
+		// ignore the rules added by trie internal processing
+		_, hasWildcardTwin := slices.BinarySearch(keys, "+."+key)
+		if !hasWildcardTwin && !f(key) {
+			return
+		}
+	}
+}
+
 var _ mrsRuleStrategy = (*domainStrategy)(nil)
 
 func NewDomainStrategy() *domainStrategy {
diff --git a/rules/provider/ipcidr_strategy.go b/rules/provider/ipcidr_strategy.go
index 87cf7a2d..9efffed9 100644
--- a/rules/provider/ipcidr_strategy.go
+++ b/rules/provider/ipcidr_strategy.go
@@ -3,6 +3,7 @@ package provider
 import (
 	"errors"
 	"io"
+	"net/netip"
 
 	"github.com/metacubex/mihomo/component/cidr"
 	C "github.com/metacubex/mihomo/constant"
@@ -82,6 +83,14 @@ func (i *ipcidrStrategy) WriteMrs(w io.Writer) error {
 	return i.cidrSet.WriteBin(w)
 }
 
+// DumpMrs calls f with the string form of each prefix in the CIDR set until f returns false.
+func (i *ipcidrStrategy) DumpMrs(f func(key string) bool) {
+	if i.cidrSet == nil {
+		return
+	}
+	i.cidrSet.Foreach(func(p netip.Prefix) bool { return f(p.String()) })
+}
+
 func (i *ipcidrStrategy) ToIpCidr() *netipx.IPSet {
 	return i.cidrSet.ToIPSet()
 }
diff --git a/rules/provider/mrs_converter.go b/rules/provider/mrs_converter.go
index a0830198..edc24e7e 100644
--- a/rules/provider/mrs_converter.go
+++ b/rules/provider/mrs_converter.go
@@ -3,6 +3,7 @@ package provider
 import (
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 
@@ -21,6 +22,17 @@ func ConvertToMrs(buf []byte, behavior P.RuleBehavior, format P.RuleFormat, w io
 		return errors.New("empty rule")
 	}
 	if _strategy, ok := strategy.(mrsRuleStrategy); ok {
+		if format == P.MrsRule { // export to TextRule
+			// Write one rule per line; the callback aborts the dump on the
+			// first failed write so the error can be returned, not dropped.
+			var writeErr error
+			_strategy.DumpMrs(func(key string) bool {
+				_, writeErr = fmt.Fprintln(w, key)
+				return writeErr == nil
+			})
+			return writeErr
+		}
+
 		var encoder *zstd.Encoder
 		encoder, err = zstd.NewWriter(w)
 		if err != nil {
diff --git a/rules/provider/provider.go b/rules/provider/provider.go
index 8c5d7f94..b9524c35 100644
--- a/rules/provider/provider.go
+++ b/rules/provider/provider.go
@@ -58,6 +58,7 @@ type mrsRuleStrategy interface {
 	ruleStrategy
 	FromMrs(r io.Reader, count int) error
 	WriteMrs(w io.Writer) error
+	DumpMrs(f func(key string) bool)
 }
 
 func (rp *ruleSetProvider) Type() P.ProviderType {