From da5a28a088091b86ac5b70ca732fc11cdb4c43fe Mon Sep 17 00:00:00 2001
From: dyhkwong <50692134+dyhkwong@users.noreply.github.com>
Date: Mon, 15 Jan 2024 23:33:15 +0800
Subject: [PATCH] Fix #2654 (#2941)

* fix udp dispatcher

* fix test
---
 transport/internet/udp/dispatcher.go | 29 +++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go
index 32c8c8ac..c29d4b13 100644
--- a/transport/internet/udp/dispatcher.go
+++ b/transport/internet/udp/dispatcher.go
@@ -28,7 +28,7 @@ type connEntry struct {
 
 type Dispatcher struct {
 	sync.RWMutex
-	conns      map[net.Destination]*connEntry
+	conn       *connEntry
 	dispatcher routing.Dispatcher
 	callback   ResponseCallback
 	callClose  func() error
@@ -36,19 +36,18 @@ type Dispatcher struct {
 
 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
 	return &Dispatcher{
-		conns:      make(map[net.Destination]*connEntry),
 		dispatcher: dispatcher,
 		callback:   callback,
 	}
 }
 
-func (v *Dispatcher) RemoveRay(dest net.Destination) {
+func (v *Dispatcher) RemoveRay() {
 	v.Lock()
 	defer v.Unlock()
-	if conn, found := v.conns[dest]; found {
-		common.Close(conn.link.Reader)
-		common.Close(conn.link.Writer)
-		delete(v.conns, dest)
+	if v.conn != nil {
+		common.Close(v.conn.link.Reader)
+		common.Close(v.conn.link.Writer)
+		v.conn = nil
 	}
 }
 
@@ -56,8 +55,8 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
 	v.Lock()
 	defer v.Unlock()
 
-	if entry, found := v.conns[dest]; found {
-		return entry, nil
+	if v.conn != nil {
+		return v.conn, nil
 	}
 
 	newError("establishing new connection for ", dest).WriteToLog()
@@ -65,7 +64,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
 	ctx, cancel := context.WithCancel(ctx)
 	removeRay := func() {
 		cancel()
-		v.RemoveRay(dest)
+		v.RemoveRay()
 	}
 	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
 
@@ -79,7 +78,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
 		timer:  timer,
 		cancel: removeRay,
 	}
-	v.conns[dest] = entry
+	v.conn = entry
 	go handleInput(ctx, entry, dest, v.callback, v.callClose)
 	return entry, nil
 }
@@ -130,6 +129,9 @@ func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, cal
 		}
 		timer.Update()
 		for _, b := range mb {
+			if b.UDP != nil {
+				dest = *b.UDP
+			}
 			callback(ctx, &udp.Packet{
 				Payload: b,
 				Source:  dest,
@@ -153,7 +155,6 @@ func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.Pac
 	}
 
 	d := &Dispatcher{
-		conns:      make(map[net.Destination]*connEntry),
 		dispatcher: dispatcher,
 		callback:   c.callback,
 		callClose:  c.Close,
@@ -199,7 +200,9 @@ func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
 	n := copy(raw, p)
 	buffer.Resize(0, int32(n))
 
-	c.dispatcher.Dispatch(c.ctx, net.DestinationFromAddr(addr), buffer)
+	destination := net.DestinationFromAddr(addr)
+	buffer.UDP = &destination
+	c.dispatcher.Dispatch(c.ctx, destination, buffer)
 	return n, nil
 }