From 1e22f4daa964c54abea4c8b0f09f8171398a2821 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 14 Mar 2025 12:07:23 +0800 Subject: [PATCH] chore: reduce data copying in quic sniffer and better handle data fragmentation and overlap --- common/utils/ranges.go | 22 ++++++ common/utils/ranges_test.go | 82 +++++++++++++++++++++ component/sniffer/quic_sniffer.go | 117 +++++++++++++++--------------- 3 files changed, 162 insertions(+), 59 deletions(-) create mode 100644 common/utils/ranges_test.go diff --git a/common/utils/ranges.go b/common/utils/ranges.go index f7dcf9c4..21eacda5 100644 --- a/common/utils/ranges.go +++ b/common/utils/ranges.go @@ -3,6 +3,7 @@ package utils import ( "errors" "fmt" + "sort" "strconv" "strings" @@ -149,3 +150,24 @@ func (ranges IntRanges[T]) Range(f func(t T) bool) { } } } + +func (ranges IntRanges[T]) Merge() (mergedRanges IntRanges[T]) { + if len(ranges) == 0 { + return + } + sort.Slice(ranges, func(i, j int) bool { + return ranges[i].Start() < ranges[j].Start() + }) + mergedRanges = ranges[:1] + var rangeIndex int + for _, r := range ranges[1:] { + if mergedRanges[rangeIndex].End()+1 > mergedRanges[rangeIndex].End() && // integer overflow + r.Start() > mergedRanges[rangeIndex].End()+1 { + mergedRanges = append(mergedRanges, r) + rangeIndex++ + } else if r.End() > mergedRanges[rangeIndex].End() { + mergedRanges[rangeIndex].end = r.End() + } + } + return +} diff --git a/common/utils/ranges_test.go b/common/utils/ranges_test.go new file mode 100644 index 00000000..3ae829d1 --- /dev/null +++ b/common/utils/ranges_test.go @@ -0,0 +1,82 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestMergeRanges(t *testing.T) { + t.Parallel() + for _, testRange := range []struct { + ranges IntRanges[uint16] + expected IntRanges[uint16] + }{ + { + ranges: IntRanges[uint16]{ + NewRange[uint16](0, 1), + NewRange[uint16](1, 2), + }, + expected: IntRanges[uint16]{ + NewRange[uint16](0, 2), + }, + }, + { + ranges: IntRanges[uint16]{ + NewRange[uint16](0, 3), + NewRange[uint16](5, 7), + NewRange[uint16](8, 9), + NewRange[uint16](10, 10), + }, + expected: IntRanges[uint16]{ + NewRange[uint16](0, 3), + NewRange[uint16](5, 10), + }, + }, + { + ranges: IntRanges[uint16]{ + NewRange[uint16](1, 3), + NewRange[uint16](2, 6), + NewRange[uint16](8, 10), + NewRange[uint16](15, 18), + }, + expected: IntRanges[uint16]{ + NewRange[uint16](1, 6), + NewRange[uint16](8, 10), + NewRange[uint16](15, 18), + }, + }, + { + ranges: IntRanges[uint16]{ + NewRange[uint16](1, 3), + NewRange[uint16](2, 7), + NewRange[uint16](2, 6), + }, + expected: IntRanges[uint16]{ + NewRange[uint16](1, 7), + }, + }, + { + ranges: IntRanges[uint16]{ + NewRange[uint16](1, 3), + NewRange[uint16](2, 6), + NewRange[uint16](2, 7), + }, + expected: IntRanges[uint16]{ + NewRange[uint16](1, 7), + }, + }, + { + ranges: IntRanges[uint16]{ + NewRange[uint16](1, 3), + NewRange[uint16](2, 65535), + NewRange[uint16](2, 7), + NewRange[uint16](3, 16), + }, + expected: IntRanges[uint16]{ + NewRange[uint16](1, 65535), + }, + }, + } { + assert.Equal(t, testRange.expected, testRange.ranges.Merge()) + } +} diff --git a/component/sniffer/quic_sniffer.go b/component/sniffer/quic_sniffer.go index e21380db..6cc377d2 100644 --- a/component/sniffer/quic_sniffer.go +++ b/component/sniffer/quic_sniffer.go @@ -11,6 +11,7 @@ import ( "time" "github.com/metacubex/mihomo/common/buf" + "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant" @@ -31,6 +32,10 @@ const ( // Timeout before quic sniffer all packets quicWaitConn = time.Second * 3 + + // maxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. + // This limits the size of the ClientHello and Certificates that can be received. + maxCryptoStreamOffset = 16 * (1 << 10) ) var ( @@ -72,27 +77,22 @@ func (sniffer *QuicSniffer) SniffData(b []byte) (string, error) { func (sniffer *QuicSniffer) WrapperSender(packetSender constant.PacketSender, override bool) constant.PacketSender { return &quicPacketSender{ sender: packetSender, - buffer: make([]quicDataBlock, 0), chClose: make(chan struct{}), override: override, } } -type quicDataBlock struct { - offset uint64 - length uint64 - data []byte -} - var _ constant.PacketSender = (*quicPacketSender)(nil) type quicPacketSender struct { lock sync.RWMutex - buffer []quicDataBlock - sender constant.PacketSender + ranges utils.IntRanges[uint64] + buffer []byte result string override bool + sender constant.PacketSender + chClose chan struct{} closed bool } @@ -144,11 +144,20 @@ func (q *quicPacketSender) Close() { func (q *quicPacketSender) close() { q.lock.Lock() + q.closeLocked() + q.lock.Unlock() +} + +func (q *quicPacketSender) closeLocked() { if !q.closed { close(q.chClose) q.closed = true + if q.buffer != nil { + _ = pool.Put(q.buffer) + q.buffer = nil + } + q.ranges = nil } - q.lock.Unlock() } func (q *quicPacketSender) readQuicData(b []byte) error { @@ -287,6 +296,14 @@ func (q *quicPacketSender) readQuicData(b []byte) error { buffer = buf.As(decrypted) for i := 0; !buffer.IsEmpty(); i++ { + q.lock.RLock() + if q.closed { + q.lock.RUnlock() + // close() was called, just return + return nil + } + q.lock.RUnlock() + frameType := byte(0x0) // Default to PADDING frame for frameType == 0x0 && !buffer.IsEmpty() { frameType, _ = buffer.ReadByte() @@ -337,26 +354,29 @@ func (q *quicPacketSender) readQuicData(b []byte) error { return io.ErrUnexpectedEOF } - q.lock.RLock() - if q.buffer == nil { - q.lock.RUnlock() - // sniffDone() was called, return the connection - return nil - } - q.lock.RUnlock() - - data = make([]byte, length) - - if _, err := buffer.Read(data); err != nil { // Field: Crypto Data - return io.ErrUnexpectedEOF + end := offset + length + if end > maxCryptoStreamOffset { + return io.ErrShortBuffer } q.lock.Lock() - q.buffer = append(q.buffer, quicDataBlock{ - offset: offset, - length: length, - data: data, - }) + if q.closed { + q.lock.Unlock() + // close() was called, just return + return nil + } + if q.buffer == nil { + q.buffer = pool.Get(maxCryptoStreamOffset)[:end] + } else if end > uint64(len(q.buffer)) { + q.buffer = q.buffer[:end] + } + target := q.buffer[offset:end] + if _, err := buffer.Read(target); err != nil { // Field: Crypto Data + q.lock.Unlock() + return io.ErrUnexpectedEOF + } + q.ranges = append(q.ranges, utils.NewRange(offset, end)) + q.ranges = q.ranges.Merge() q.lock.Unlock() case 0x1c: // CONNECTION_CLOSE frame, only 0x1c is permitted in initial packet if _, err = quicvarint.Read(buffer); err != nil { // Field: Error Code @@ -387,50 +407,29 @@ func (q *quicPacketSender) readQuicData(b []byte) error { func (q *quicPacketSender) tryAssemble() error { q.lock.RLock() - if q.buffer == nil { + if q.closed { q.lock.RUnlock() + // close() was called, just return return nil } - var frameLen uint64 - for _, fragment := range q.buffer { - frameLen += fragment.length - } - - buffer := buf.NewSize(int(frameLen)) - - var index uint64 - var length int - -loop: - for { - for _, fragment := range q.buffer { - if fragment.offset == index { - if _, err := buffer.Write(fragment.data); err != nil { - return err - } - index = fragment.offset + fragment.length - length++ - continue loop - } - } - - break - } - - domain, err := ReadClientHello(buffer.Bytes()) - if err != nil { + if len(q.ranges) != 1 || q.ranges[0].Start() != 0 || q.ranges[0].End() != uint64(len(q.buffer)) { q.lock.RUnlock() + return ErrNoClue + } + + domain, err := ReadClientHello(q.buffer) + q.lock.RUnlock() + if err != nil { return err } - q.lock.RUnlock() q.lock.Lock() q.result = *domain + q.closeLocked() q.lock.Unlock() - q.close() - return err + return nil } func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {