From 1e989b68bd78b1b7792f08e7cbad7821bb130c97 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Fri, 31 Mar 2023 20:43:49 +0800 Subject: [PATCH] fix: wildcard match --- component/trie/set_test.go | 12 ++++---- component/trie/sskv.go | 57 +++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/component/trie/set_test.go b/component/trie/set_test.go index b3bb46e0..32cc7d99 100644 --- a/component/trie/set_test.go +++ b/component/trie/set_test.go @@ -38,12 +38,14 @@ func TestDomainWildcard(t *testing.T) { "*.baidu.com", "www.baidu.com", "*.*.qq.com", + "test.*.baidu.com", } set := trie.NewDomainSet(domainSet) assert.NotNil(t, set) - // assert.True(t, set.Has("www.baidu.com")) - // assert.False(t, set.Has("test.test.baidu.com")) - assert.True(t,set.Has("test.test.qq.com")) - assert.False(t,set.Has("test.qq.com")) - assert.False(t,set.Has("test.test.test.qq.com")) + assert.True(t, set.Has("www.baidu.com")) + assert.True(t, set.Has("test.test.baidu.com")) + assert.True(t, set.Has("test.test.qq.com")) + assert.False(t, set.Has("test.baidu.com")) + assert.False(t, set.Has("test.qq.com")) + assert.False(t, set.Has("test.test.test.qq.com")) } diff --git a/component/trie/sskv.go b/component/trie/sskv.go index 410015a1..b20e8280 100644 --- a/component/trie/sskv.go +++ b/component/trie/sskv.go @@ -108,28 +108,34 @@ func (ss *DomainSet) Has(key string) bool { // go to next level nodeId, bmIdx := 0, 0 type wildcardCursor struct { - index, bmIdx int - find bool - } - cursor := wildcardCursor{ - find: false, + nodeId, bmIdx, index int + find bool } + cursor := wildcardCursor{} for i := 0; i < len(key); i++ { + RESTART: c := key[i] for ; ; bmIdx++ { if getBit(ss.labelBitmap, bmIdx) != 0 { if cursor.find { - // gets the node next to the cursor - wildcardNextNodeId := countZeros(ss.labelBitmap, ss.ranks, cursor.bmIdx+1) - // next is a leaf, and the domain name has no next level - if getBit(ss.leaves, wildcardNextNodeId) != 0 && cursor.index == len(key) { - return true + // back wildcard and find next node + nextNodeId := countZeros(ss.labelBitmap, ss.ranks, cursor.bmIdx+1) + nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1 + j := cursor.index + for ; j < len(key) && key[j] != domainStepByte; j++ { + } + if j == len(key) { + return getBit(ss.leaves, nextNodeId) != 0 + } + for ; ; nextBmIdx++ { + if ss.labels[nextBmIdx-nextNodeId] == domainStepByte { + bmIdx = nextBmIdx + nodeId = nextNodeId + i = j + cursor.find=false + goto RESTART + } } - // reset, and jump to the cursor location - cursor.find = false - i = cursor.index - bmIdx = cursor.bmIdx - break } return false } @@ -139,18 +145,10 @@ func (ss *DomainSet) Has(key string) bool { } else if ss.labels[bmIdx-nodeId] == wildcardByte { cursor.find = true cursor.bmIdx = bmIdx - // gets the first domain step that follows - // If not, it is the last domain level, which is represented by len(key) - if index := strings.Index(key[i:], domainStep); index > 0 { - cursor.index = index + i - 1 - } else { - cursor.index = len(key) - } - break + cursor.nodeId = nodeId + cursor.index = i } else if ss.labels[bmIdx-nodeId] == c { - if ss.labels[bmIdx-nodeId] == domainStepByte { - cursor.find = false - } + cursor.find=false break } } @@ -158,11 +156,8 @@ func (ss *DomainSet) Has(key string) bool { bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1 } - if getBit(ss.leaves, nodeId) != 0 { - return true - } else { - return false - } + return getBit(ss.leaves, nodeId) != 0 + } func setBit(bm *[]uint64, i int, v int) {