From 129b2be9c1d3cad9155ce559b68949770f391216 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sun, 2 Mar 2025 23:53:33 -0500 Subject: [PATCH] Use capsulate protocol for large UDP packet - make datagram transport without mux functionality - it is now recommended to always pair with mux-cool (XUDP new tunnel non-zero session id) --- transport/internet/quic/conn.go | 136 ++++++++++++++++++++++++--- transport/internet/quic/dialer.go | 67 +------------ transport/internet/quic/hub.go | 9 +- transport/internet/quic/quic_test.go | 17 +++- 4 files changed, 140 insertions(+), 89 deletions(-) diff --git a/transport/internet/quic/conn.go b/transport/internet/quic/conn.go index e04a4137..f79571d5 100644 --- a/transport/internet/quic/conn.go +++ b/transport/internet/quic/conn.go @@ -6,22 +6,83 @@ import ( "github.com/quic-go/quic-go" "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/signal/done" ) +var MaxIncomingStreams = 2 +var currentStream = 0 + type interConn struct { - ctx context.Context - quicConn quic.Connection - local net.Addr - remote net.Addr + ctx context.Context + quicConn quic.Connection // small udp packet can be sent with Datagram directly + streams []quic.Stream // other packets can be sent via steam, it offer mux, reliability, fragmentation and ordering + readChannel chan readResult + done *done.Instance + local net.Addr + remote net.Addr } -func (c *interConn) Read(b []byte) (int, error) { - received, e := c.quicConn.ReceiveDatagram(c.ctx) - if e != nil { - return 0, e +type readResult struct { + buffer []byte + err error +} + +func NewConnInitReader(ctx context.Context, quicConn quic.Connection, done *done.Instance, remote net.Addr) *interConn { + c := &interConn{ + ctx: ctx, + quicConn: quicConn, + readChannel: make(chan readResult), + done: done, + local: quicConn.LocalAddr(), + remote: remote, } - nBytes := copy(b, received[:]) + go func() { + for { + received, e := c.quicConn.ReceiveDatagram(c.ctx) + c.readChannel <- readResult{buffer: received, err: e} + } + }() + go c.acceptStreams() + return c +} + +func (c *interConn) acceptStreams() { + for { + stream, err := c.quicConn.AcceptStream(context.Background()) + if err != nil { + errors.LogInfoInner(context.Background(), err, "failed to accept stream") + select { + case <-c.quicConn.Context().Done(): + return + case <-c.done.Wait(): + if err := c.quicConn.CloseWithError(0, ""); err != nil { + errors.LogInfoInner(context.Background(), err, "failed to close connection") + } + return + default: + time.Sleep(time.Second) + continue + } + } + go func() { + for { + received := make([]byte, buf.Size) + i, e := stream.Read(received) + c.readChannel <- readResult{buffer: received[:i], err: e} + } + }() + c.streams = append(c.streams, stream) + } +} + +func (c *interConn) Read(b []byte) (int, error) { + received := <- c.readChannel + if received.err != nil { + return 0, received.err + } + nBytes := copy(b, received.buffer[:]) return nBytes, nil } @@ -33,11 +94,37 @@ func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error { } func (c *interConn) Write(b []byte) (int, error) { - return len(b), c.quicConn.SendDatagram(b) + var err = c.quicConn.SendDatagram(b) + if _, ok := err.(*quic.DatagramTooLargeError); ok { + if len(c.streams) < MaxIncomingStreams { + stream, err := c.quicConn.OpenStream() + if err == nil { + c.streams = append(c.streams, stream) + } else { + errors.LogInfoInner(c.ctx, err, "failed to openStream: ") + } + } + currentStream++; + if currentStream > len(c.streams) - 1 { + currentStream = 0; + } + return c.streams[currentStream].Write(b) + } + if err != nil { + return 0, err + } + return len(b), nil } func (c *interConn) Close() error { - return nil + var err error + for _, s := range c.streams { + e := s.Close() + if e != nil { + err = e + } + } + return err } func (c *interConn) LocalAddr() net.Addr { @@ -49,13 +136,34 @@ func (c *interConn) RemoteAddr() net.Addr { } func (c *interConn) SetDeadline(t time.Time) error { - return nil + var err error + for _, s := range c.streams { + e := s.SetDeadline(t) + if e != nil { + err = e + } + } + return err } func (c *interConn) SetReadDeadline(t time.Time) error { - return nil + var err error + for _, s := range c.streams { + e := s.SetReadDeadline(t) + if e != nil { + err = e + } + } + return err } func (c *interConn) SetWriteDeadline(t time.Time) error { - return nil + var err error + for _, s := range c.streams { + e := s.SetWriteDeadline(t) + if e != nil { + err = e + } + } + return err } diff --git a/transport/internet/quic/dialer.go b/transport/internet/quic/dialer.go index 0d8bdf9f..65eb59a7 100644 --- a/transport/internet/quic/dialer.go +++ b/transport/internet/quic/dialer.go @@ -2,38 +2,18 @@ package quic import ( "context" - "sync" "time" "github.com/quic-go/quic-go" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" ) -type connectionContext struct { - rawConn *net.UDPConn - conn quic.Connection -} - -type clientConnections struct { - access sync.Mutex - conns map[net.Destination][]*connectionContext - // cleanup *task.Periodic -} - -func isActive(s quic.Connection) bool { - select { - case <-s.Context().Done(): - return false - default: - return true - } -} - func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { tlsConfig := tls.ConfigFromStreamSettings(streamSettings) if tlsConfig == nil { @@ -68,38 +48,11 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me config := streamSettings.ProtocolSettings.(*Config) - return client.openConnection(ctx, destAddr, config, tlsConfig, streamSettings.SocketSettings) + return openConnection(ctx, destAddr, config, tlsConfig, streamSettings.SocketSettings) } -func (s *clientConnections) openConnection(ctx context.Context, destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) { - s.access.Lock() - defer s.access.Unlock() - - if s.conns == nil { - s.conns = make(map[net.Destination][]*connectionContext) - } - +func openConnection(ctx context.Context, destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) { dest := net.DestinationFromAddr(destAddr) - - var conns []*connectionContext - if s, found := s.conns[dest]; found { - conns = s - } - - if len(conns) > 0 { - s := conns[len(conns)-1] - if isActive(s.conn) { - return &interConn{ - ctx: ctx, - quicConn: s.conn, - local: s.conn.LocalAddr(), - remote: destAddr, - }, nil - } else { - errors.LogInfo(ctx, "current quic connection is not active!") - } - } - errors.LogInfo(ctx, "dialing quic to ", dest) rawConn, err := internet.DialSystem(ctx, dest, sockopt) if err != nil { @@ -134,21 +87,9 @@ func (s *clientConnections) openConnection(ctx context.Context, destAddr net.Add return nil, err } - context := &connectionContext{ - conn: conn, - rawConn: udpConn, - } - s.conns[dest] = append(conns, context) - return &interConn{ - ctx: ctx, - quicConn: context.conn, - local: context.conn.LocalAddr(), - remote: destAddr, - }, nil + return NewConnInitReader(ctx, conn, done.New(), destAddr), nil } -var client clientConnections - func init() { common.Must(internet.RegisterTransportDialer(protocolName, Dial)) } diff --git a/transport/internet/quic/hub.go b/transport/internet/quic/hub.go index ed3f9aa8..bfeef877 100644 --- a/transport/internet/quic/hub.go +++ b/transport/internet/quic/hub.go @@ -33,12 +33,7 @@ func (l *Listener) keepAccepting(ctx context.Context) { time.Sleep(time.Second) continue } - l.addConn(&interConn{ - ctx: ctx, - quicConn: conn, - local: conn.LocalAddr(), - remote: conn.RemoteAddr(), - }) + l.addConn(NewConnInitReader(ctx, conn, l.done, conn.RemoteAddr())) } } @@ -81,7 +76,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti KeepAlivePeriod: 0, HandshakeIdleTimeout: time.Second * 8, MaxIdleTimeout: time.Second * 300, - MaxIncomingStreams: 32, + MaxIncomingStreams: 2, MaxIncomingUniStreams: -1, EnableDatagrams: true, } diff --git a/transport/internet/quic/quic_test.go b/transport/internet/quic/quic_test.go index d7ddd592..eb1d707b 100644 --- a/transport/internet/quic/quic_test.go +++ b/transport/internet/quic/quic_test.go @@ -18,7 +18,15 @@ import ( "github.com/xtls/xray-core/transport/internet/tls" ) -func TestQuicConnection(t *testing.T) { +func TestShortQuicConnection(t *testing.T) { + testQuicConnection(t, 1024) +} + +func TestLongQuicConnection(t *testing.T) { + testQuicConnection(t, 1500) +} + +func testQuicConnection(t *testing.T, dataLen int32) { port := udp.PickPort() listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ @@ -69,15 +77,14 @@ func TestQuicConnection(t *testing.T) { common.Must(err) defer conn.Close() - const N = 1024 - b1 := make([]byte, N) + b1 := make([]byte, dataLen) common.Must2(rand.Read(b1)) b2 := buf.New() common.Must2(conn.Write(b1)) b2.Clear() - common.Must2(b2.ReadFullFrom(conn, N)) + common.Must2(b2.ReadFullFrom(conn, dataLen)) if r := cmp.Diff(b2.Bytes(), b1); r != "" { t.Error(r) } @@ -85,7 +92,7 @@ func TestQuicConnection(t *testing.T) { common.Must2(conn.Write(b1)) b2.Clear() - common.Must2(b2.ReadFullFrom(conn, N)) + common.Must2(b2.ReadFullFrom(conn, dataLen)) if r := cmp.Diff(b2.Bytes(), b1); r != "" { t.Error(r) }