chore: reduce data copying in quic sniffer and better handle data fragmentation and overlap

This commit is contained in:
wwqgtxx 2025-03-14 12:07:23 +08:00
parent a7a796bb30
commit 1e22f4daa9
3 changed files with 162 additions and 59 deletions

View file

@ -3,6 +3,7 @@ package utils
import (
"errors"
"fmt"
"sort"
"strconv"
"strings"
@ -149,3 +150,24 @@ func (ranges IntRanges[T]) Range(f func(t T) bool) {
}
}
}
func (ranges IntRanges[T]) Merge() (mergedRanges IntRanges[T]) {
if len(ranges) == 0 {
return
}
sort.Slice(ranges, func(i, j int) bool {
return ranges[i].Start() < ranges[j].Start()
})
mergedRanges = ranges[:1]
var rangeIndex int
for _, r := range ranges[1:] {
if mergedRanges[rangeIndex].End()+1 > mergedRanges[rangeIndex].End() && // integer overflow
r.Start() > mergedRanges[rangeIndex].End()+1 {
mergedRanges = append(mergedRanges, r)
rangeIndex++
} else if r.End() > mergedRanges[rangeIndex].End() {
mergedRanges[rangeIndex].end = r.End()
}
}
return
}

View file

@ -0,0 +1,82 @@
package utils
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestMergeRanges(t *testing.T) {
t.Parallel()
for _, testRange := range []struct {
ranges IntRanges[uint16]
expected IntRanges[uint16]
}{
{
ranges: IntRanges[uint16]{
NewRange[uint16](0, 1),
NewRange[uint16](1, 2),
},
expected: IntRanges[uint16]{
NewRange[uint16](0, 2),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](0, 3),
NewRange[uint16](5, 7),
NewRange[uint16](8, 9),
NewRange[uint16](10, 10),
},
expected: IntRanges[uint16]{
NewRange[uint16](0, 3),
NewRange[uint16](5, 10),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 6),
NewRange[uint16](8, 10),
NewRange[uint16](15, 18),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 6),
NewRange[uint16](8, 10),
NewRange[uint16](15, 18),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 7),
NewRange[uint16](2, 6),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 7),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 6),
NewRange[uint16](2, 7),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 7),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 65535),
NewRange[uint16](2, 7),
NewRange[uint16](3, 16),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 65535),
},
},
} {
assert.Equal(t, testRange.expected, testRange.ranges.Merge())
}
}

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/metacubex/mihomo/common/buf"
"github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/constant"
C "github.com/metacubex/mihomo/constant"
@ -31,6 +32,10 @@ const (
// Timeout before quic sniffer all packets
quicWaitConn = time.Second * 3
// maxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
// This limits the size of the ClientHello and Certificates that can be received.
maxCryptoStreamOffset = 16 * (1 << 10)
)
var (
@ -72,27 +77,22 @@ func (sniffer *QuicSniffer) SniffData(b []byte) (string, error) {
func (sniffer *QuicSniffer) WrapperSender(packetSender constant.PacketSender, override bool) constant.PacketSender {
return &quicPacketSender{
sender: packetSender,
buffer: make([]quicDataBlock, 0),
chClose: make(chan struct{}),
override: override,
}
}
type quicDataBlock struct {
offset uint64
length uint64
data []byte
}
var _ constant.PacketSender = (*quicPacketSender)(nil)
type quicPacketSender struct {
lock sync.RWMutex
buffer []quicDataBlock
sender constant.PacketSender
ranges utils.IntRanges[uint64]
buffer []byte
result string
override bool
sender constant.PacketSender
chClose chan struct{}
closed bool
}
@ -144,11 +144,20 @@ func (q *quicPacketSender) Close() {
func (q *quicPacketSender) close() {
q.lock.Lock()
q.closeLocked()
q.lock.Unlock()
}
func (q *quicPacketSender) closeLocked() {
if !q.closed {
close(q.chClose)
q.closed = true
if q.buffer != nil {
_ = pool.Put(q.buffer)
q.buffer = nil
}
q.ranges = nil
}
q.lock.Unlock()
}
func (q *quicPacketSender) readQuicData(b []byte) error {
@ -287,6 +296,14 @@ func (q *quicPacketSender) readQuicData(b []byte) error {
buffer = buf.As(decrypted)
for i := 0; !buffer.IsEmpty(); i++ {
q.lock.RLock()
if q.closed {
q.lock.RUnlock()
// close() was called, just return
return nil
}
q.lock.RUnlock()
frameType := byte(0x0) // Default to PADDING frame
for frameType == 0x0 && !buffer.IsEmpty() {
frameType, _ = buffer.ReadByte()
@ -337,26 +354,29 @@ func (q *quicPacketSender) readQuicData(b []byte) error {
return io.ErrUnexpectedEOF
}
q.lock.RLock()
if q.buffer == nil {
q.lock.RUnlock()
// sniffDone() was called, return the connection
return nil
}
q.lock.RUnlock()
data = make([]byte, length)
if _, err := buffer.Read(data); err != nil { // Field: Crypto Data
return io.ErrUnexpectedEOF
end := offset + length
if end > maxCryptoStreamOffset {
return io.ErrShortBuffer
}
q.lock.Lock()
q.buffer = append(q.buffer, quicDataBlock{
offset: offset,
length: length,
data: data,
})
if q.closed {
q.lock.Unlock()
// close() was called, just return
return nil
}
if q.buffer == nil {
q.buffer = pool.Get(maxCryptoStreamOffset)[:end]
} else if end > uint64(len(q.buffer)) {
q.buffer = q.buffer[:end]
}
target := q.buffer[offset:end]
if _, err := buffer.Read(target); err != nil { // Field: Crypto Data
q.lock.Unlock()
return io.ErrUnexpectedEOF
}
q.ranges = append(q.ranges, utils.NewRange(offset, end))
q.ranges = q.ranges.Merge()
q.lock.Unlock()
case 0x1c: // CONNECTION_CLOSE frame, only 0x1c is permitted in initial packet
if _, err = quicvarint.Read(buffer); err != nil { // Field: Error Code
@ -387,50 +407,29 @@ func (q *quicPacketSender) readQuicData(b []byte) error {
func (q *quicPacketSender) tryAssemble() error {
q.lock.RLock()
if q.buffer == nil {
if q.closed {
q.lock.RUnlock()
// close() was called, just return
return nil
}
var frameLen uint64
for _, fragment := range q.buffer {
frameLen += fragment.length
}
buffer := buf.NewSize(int(frameLen))
var index uint64
var length int
loop:
for {
for _, fragment := range q.buffer {
if fragment.offset == index {
if _, err := buffer.Write(fragment.data); err != nil {
return err
}
index = fragment.offset + fragment.length
length++
continue loop
}
}
break
}
domain, err := ReadClientHello(buffer.Bytes())
if err != nil {
if len(q.ranges) != 1 || q.ranges[0].Start() != 0 || q.ranges[0].End() != uint64(len(q.buffer)) {
q.lock.RUnlock()
return ErrNoClue
}
domain, err := ReadClientHello(q.buffer)
q.lock.RUnlock()
if err != nil {
return err
}
q.lock.RUnlock()
q.lock.Lock()
q.result = *domain
q.closeLocked()
q.lock.Unlock()
q.close()
return err
return nil
}
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {