From 7de24e26b4d9dc4b0b23e4e3da81d416fa88a67b Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 3 Apr 2025 23:41:24 +0800 Subject: [PATCH] fix: StreamGunWithConn not synchronously close the incoming net.Conn --- adapter/outbound/base.go | 2 +- adapter/outbound/trojan.go | 4 ++-- adapter/outbound/vless.go | 4 ++-- adapter/outbound/vmess.go | 4 ++-- transport/gun/gun.go | 48 +++++++++++++++++++++++++++----------- transport/gun/server.go | 3 +-- 6 files changed, 42 insertions(+), 23 deletions(-) diff --git a/adapter/outbound/base.go b/adapter/outbound/base.go index 84e9b7c2..9c3515c0 100644 --- a/adapter/outbound/base.go +++ b/adapter/outbound/base.go @@ -281,7 +281,7 @@ func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn { epc := N.NewEnhancePacketConn(pc) if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly - epc = N.NewRefPacketConn(epc, a) // add ref for autoCloseProxyAdapter + epc = N.NewRefPacketConn(epc, a) // add ref for autoCloseProxyAdapter } return &packetConn{epc, []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), parseRemoteDestination(a.Addr())} } diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index c62c1eb4..b37235e0 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -313,7 +313,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { } if option.Network == "grpc" { - dialFn := func(network, addr string) (net.Conn, error) { + dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { var err error var cDialer C.Dialer = dialer.NewDialer(t.Base.DialOptions()...) if len(t.option.DialerProxy) > 0 { @@ -322,7 +322,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { return nil, err } } - c, err := cDialer.DialContext(context.Background(), "tcp", t.addr) + c, err := cDialer.DialContext(ctx, "tcp", t.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error()) } diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 6609812e..8f5de920 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -571,7 +571,7 @@ func NewVless(option VlessOption) (*Vless, error) { option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com") } case "grpc": - dialFn := func(network, addr string) (net.Conn, error) { + dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { var err error var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...) if len(v.option.DialerProxy) > 0 { @@ -580,7 +580,7 @@ func NewVless(option VlessOption) (*Vless, error) { return nil, err } } - c, err := cDialer.DialContext(context.Background(), "tcp", v.addr) + c, err := cDialer.DialContext(ctx, "tcp", v.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) } diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 54a25711..00f8a31e 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -478,7 +478,7 @@ func NewVmess(option VmessOption) (*Vmess, error) { option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com") } case "grpc": - dialFn := func(network, addr string) (net.Conn, error) { + dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { var err error var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...) if len(v.option.DialerProxy) > 0 { @@ -487,7 +487,7 @@ func NewVmess(option VmessOption) (*Vmess, error) { return nil, err } } - c, err := cDialer.DialContext(context.Background(), "tcp", v.addr) + c, err := cDialer.DialContext(ctx, "tcp", v.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) } diff --git a/transport/gun/gun.go b/transport/gun/gun.go index 12a11884..7738105a 100644 --- a/transport/gun/gun.go +++ b/transport/gun/gun.go @@ -36,12 +36,12 @@ var defaultHeader = http.Header{ "user-agent": []string{"grpc-go/1.36.0"}, } -type DialFn = func(network, addr string) (net.Conn, error) +type DialFn = func(ctx context.Context, network, addr string) (net.Conn, error) type Conn struct { - initFn func() (io.ReadCloser, netAddr, error) - writer io.Writer - flusher http.Flusher + initFn func() (io.ReadCloser, netAddr, error) + writer io.Writer + closer io.Closer netAddr reader io.ReadCloser @@ -149,8 +149,8 @@ func (g *Conn) Write(b []byte) (n int, err error) { err = g.err } - if g.flusher != nil { - g.flusher.Flush() + if flusher, ok := g.writer.(http.Flusher); ok { + flusher.Flush() } return len(b), err @@ -172,8 +172,8 @@ func (g *Conn) WriteBuffer(buffer *buf.Buffer) error { err = g.err } - if g.flusher != nil { - g.flusher.Flush() + if flusher, ok := g.writer.(http.Flusher); ok { + flusher.Flush() } return err @@ -185,14 +185,27 @@ func (g *Conn) FrontHeadroom() int { func (g *Conn) Close() error { g.close.Store(true) + var errorArr []error + if reader := g.reader; reader != nil { - reader.Close() + if err := reader.Close(); err != nil { + errorArr = append(errorArr, err) + } } if closer, ok := g.writer.(io.Closer); ok { - return closer.Close() + if err := closer.Close(); err != nil { + errorArr = append(errorArr, err) + } } - return nil + + if closer := g.closer; closer != nil { + if err := closer.Close(); err != nil { + errorArr = append(errorArr, err) + } + } + + return errors.Join(errorArr...) } func (g *Conn) SetReadDeadline(t time.Time) error { return g.SetDeadline(t) } @@ -212,7 +225,7 @@ func (g *Conn) SetDeadline(t time.Time) error { func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, realityConfig *tlsC.RealityConfig) *TransportWrap { dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - pconn, err := dialFn(network, addr) + pconn, err := dialFn(ctx, network, addr) if err != nil { return nil, err } @@ -327,10 +340,17 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er } func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config, realityConfig *tlsC.RealityConfig) (net.Conn, error) { - dialFn := func(network, addr string) (net.Conn, error) { + dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { return conn, nil } transport := NewHTTP2Client(dialFn, tlsConfig, cfg.ClientFingerprint, realityConfig) - return StreamGunWithTransport(transport, cfg) + c, err := StreamGunWithTransport(transport, cfg) + if err != nil { + return nil, err + } + if c, ok := c.(*Conn); ok { // The incoming net.Conn should be closed synchronously with the generated gun.Conn + c.closer = conn + } + return c, nil } diff --git a/transport/gun/server.go b/transport/gun/server.go index 8b506542..f4f6e948 100644 --- a/transport/gun/server.go +++ b/transport/gun/server.go @@ -56,8 +56,7 @@ func NewServerHandler(options ServerOption) http.Handler { } return request.Body, nAddr, nil }, - writer: writer, - flusher: writer.(http.Flusher), + writer: writer, } wrapper := &h2ConnWrapper{