diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 5d457cf1..25834432 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -60,7 +60,10 @@ func (sd *Dispatcher) forceSniff(metadata *C.Metadata) bool { return false } -func (sd *Dispatcher) UDPSniff(packet C.PacketAdapter) bool { +// UDPSniff is called when a UDP NAT is created and passed the first initialization packet. +// It may return a wrapped packetSender if the sniffer process needs to wait for multiple packets. +// This function must be non-blocking, and any blocking operations should be done in the wrapped packetSender. +func (sd *Dispatcher) UDPSniff(packet C.PacketAdapter, packetSender C.PacketSender) C.PacketSender { metadata := packet.Metadata() if sd.shouldOverride(metadata) { for sniffer, config := range sd.sniffers { @@ -75,13 +78,13 @@ func (sd *Dispatcher) UDPSniff(packet C.PacketAdapter) bool { } sd.replaceDomain(metadata, host, overrideDest) - return true + return packetSender } } } } - return false + return packetSender } // TCPSniff returns true if the connection is sniffed to have a domain diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index a4486df7..b9507930 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -378,12 +378,14 @@ func handleUDPConn(packet C.PacketAdapter) { return } - if sniffingEnable && snifferDispatcher.Enable() { - snifferDispatcher.UDPSniff(packet) - } - key := packet.Key() - sender, loaded := natTable.GetOrCreate(key, newPacketSender) + sender, loaded := natTable.GetOrCreate(key, func() C.PacketSender { + sender := newPacketSender() + if sniffingEnable && snifferDispatcher.Enable() { + return snifferDispatcher.UDPSniff(packet, sender) + } + return sender + }) if !loaded { dial := func() (C.PacketConn, C.WriteBackProxy, error) { if err := sender.ResolveUDP(metadata); err != nil {