diff --git a/component/fakeip/pool_test.go b/component/fakeip/pool_test.go index 9c05a327..ed52fb6d 100644 --- a/component/fakeip/pool_test.go +++ b/component/fakeip/pool_test.go @@ -152,7 +152,8 @@ func TestPool_CycleUsed(t *testing.T) { func TestPool_Skip(t *testing.T) { ipnet := netip.MustParsePrefix("192.168.0.1/29") tree := trie.New[struct{}]() - tree.Insert("example.com", struct{}{}) + assert.NoError(t, tree.Insert("example.com", struct{}{})) + assert.False(t, tree.IsEmpty()) pools, tempfile, err := createPools(Options{ IPNet: ipnet, Size: 10, diff --git a/component/trie/domain.go b/component/trie/domain.go index 6d3e37f7..87dfeda6 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -126,7 +126,7 @@ func (t *DomainTrie[T]) Optimize() { func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) { for key, data := range t.root.getChildren() { recursion([]string{key}, data, fn) - if data != nil && data.inited { + if !data.isEmpty() { if !fn(joinDomain([]string{key}), data.data) { return } @@ -135,16 +135,16 @@ func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) { } func (t *DomainTrie[T]) IsEmpty() bool { - if t == nil { + if t == nil || t.root == nil { return true } - return t.root.isEmpty() + return len(t.root.getChildren()) == 0 } func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool { for key, data := range node.getChildren() { newItems := append([]string{key}, items...) - if data != nil && data.inited { + if !data.isEmpty() { domain := joinDomain(newItems) if domain[0] == domainStepByte { domain = complexWildcard + domain diff --git a/component/trie/domain_set_test.go b/component/trie/domain_set_test.go index e343d11d..38ba1622 100644 --- a/component/trie/domain_set_test.go +++ b/component/trie/domain_set_test.go @@ -40,6 +40,7 @@ func TestDomainSet(t *testing.T) { for _, domain := range domainSet { assert.NoError(t, tree.Insert(domain, struct{}{})) } + assert.False(t, tree.IsEmpty()) set := tree.NewDomainSet() assert.NotNil(t, set) assert.True(t, set.Has("test.cn")) @@ -68,6 +69,7 @@ func TestDomainSetComplexWildcard(t *testing.T) { for _, domain := range domainSet { assert.NoError(t, tree.Insert(domain, struct{}{})) } + assert.False(t, tree.IsEmpty()) set := tree.NewDomainSet() assert.NotNil(t, set) assert.False(t, set.Has("google.com")) @@ -90,6 +92,7 @@ func TestDomainSetWildcard(t *testing.T) { for _, domain := range domainSet { assert.NoError(t, tree.Insert(domain, struct{}{})) } + assert.False(t, tree.IsEmpty()) set := tree.NewDomainSet() assert.NotNil(t, set) assert.True(t, set.Has("www.baidu.com"))