diff --git a/adapter/outbound/http.go b/adapter/outbound/http.go index 3e7060e6..b734290a 100644 --- a/adapter/outbound/http.go +++ b/adapter/outbound/http.go @@ -61,7 +61,7 @@ func (h *Http) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { // DialContext implements C.ProxyAdapter func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - return h.DialContextWithDialer(ctx, dialer.Dialer{Options: h.Base.DialOptions(opts...)}, metadata) + return h.DialContextWithDialer(ctx, dialer.NewDialer(h.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 46a8b9bf..c318d263 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -84,7 +84,7 @@ func (ss *ShadowSocks) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e // DialContext implements C.ProxyAdapter func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - return ss.DialContextWithDialer(ctx, dialer.Dialer{Options: ss.Base.DialOptions(opts...)}, metadata) + return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter @@ -105,7 +105,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale // ListenPacketContext implements C.ProxyAdapter func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { - return ss.ListenPacketWithDialer(ctx, dialer.Dialer{Options: ss.Base.DialOptions(opts...)}, metadata) + return ss.ListenPacketWithDialer(ctx, dialer.NewDialer(ss.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/shadowsocksr.go b/adapter/outbound/shadowsocksr.go index 99b8edc3..e84de879 100644 --- a/adapter/outbound/shadowsocksr.go +++ b/adapter/outbound/shadowsocksr.go @@ -60,7 +60,7 @@ func (ssr *ShadowSocksR) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, // DialContext implements C.ProxyAdapter func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - return ssr.DialContextWithDialer(ctx, dialer.Dialer{Options: ssr.Base.DialOptions(opts...)}, metadata) + return ssr.DialContextWithDialer(ctx, dialer.NewDialer(ssr.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter @@ -81,7 +81,7 @@ func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dia // ListenPacketContext implements C.ProxyAdapter func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { - return ssr.ListenPacketWithDialer(ctx, dialer.Dialer{Options: ssr.Base.DialOptions(opts...)}, metadata) + return ssr.ListenPacketWithDialer(ctx, dialer.NewDialer(ssr.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index bc1fa0c1..1331b526 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -78,7 +78,7 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d return NewConn(c, s), err } - return s.DialContextWithDialer(ctx, dialer.Dialer{Options: s.Base.DialOptions(opts...)}, metadata) + return s.DialContextWithDialer(ctx, dialer.NewDialer(s.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter @@ -99,7 +99,7 @@ func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta // ListenPacketContext implements C.ProxyAdapter func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { - return s.ListenPacketWithDialer(ctx, dialer.Dialer{Options: s.Base.DialOptions(opts...)}, metadata) + return s.ListenPacketWithDialer(ctx, dialer.NewDialer(s.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/socks5.go b/adapter/outbound/socks5.go index ccd13da7..28d41180 100644 --- a/adapter/outbound/socks5.go +++ b/adapter/outbound/socks5.go @@ -65,7 +65,7 @@ func (ss *Socks5) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) // DialContext implements C.ProxyAdapter func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - return ss.DialContextWithDialer(ctx, dialer.Dialer{Options: ss.Base.DialOptions(opts...)}, metadata) + return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index 9bfb0126..e3abeab2 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -120,7 +120,7 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata, opts ... return NewConn(c, t), nil } - return t.DialContextWithDialer(ctx, dialer.Dialer{Options: t.Base.DialOptions(opts...)}, metadata) + return t.DialContextWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter @@ -164,7 +164,7 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata, pc := t.instance.PacketConn(c) return newPacketConn(pc, t), err } - return t.ListenPacketWithDialer(ctx, dialer.Dialer{Options: t.Base.DialOptions(opts...)}, metadata) + return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index fa171187..6ffc0095 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -63,6 +63,15 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...di return NewConn(conn, t), err } +// DialContextWithDialer implements C.ProxyAdapter +func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) { + conn, err := t.client.DialContextWithDialer(ctx, dialer, metadata, t.dialWithDialer) + if err != nil { + return nil, err + } + return NewConn(conn, t), err +} + // ListenPacketContext implements C.ProxyAdapter func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { opts = t.Base.DialOptions(opts...) @@ -73,13 +82,31 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, op return newPacketConn(pc, t), nil } +// ListenPacketWithDialer implements C.ProxyAdapter +func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) { + pc, err := t.client.ListenPacketWithDialer(ctx, dialer, metadata, t.dialWithDialer) + if err != nil { + return nil, err + } + return newPacketConn(pc, t), nil +} + +// SupportWithDialer implements C.ProxyAdapter +func (t *Tuic) SupportWithDialer() bool { + return true +} + func (t *Tuic) dial(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + return t.dialWithDialer(ctx, dialer.NewDialer(opts...)) +} + +func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (pc net.PacketConn, addr net.Addr, err error) { udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer) if err != nil { return nil, nil, err } addr = udpAddr - pc, err = dialer.ListenPacket(ctx, dialer.ParseNetwork("udp", udpAddr.AddrPort().Addr()), "", opts...) + pc, err = dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort()) if err != nil { return nil, nil, err } diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 28ebfb05..5261fefd 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -219,7 +219,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d return NewConn(c, v), nil } - return v.DialContextWithDialer(ctx, dialer.Dialer{Options: v.Base.DialOptions(opts...)}, metadata) + return v.DialContextWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter @@ -267,7 +267,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o return v.ListenPacketOnStreamConn(c, metadata) } - return v.ListenPacketWithDialer(ctx, dialer.Dialer{Options: v.Base.DialOptions(opts...)}, metadata) + return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 5697e056..2d1bd94d 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -232,7 +232,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d return NewConn(c, v), nil } - return v.DialContextWithDialer(ctx, dialer.Dialer{Options: v.Base.DialOptions(opts...)}, metadata) + return v.DialContextWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata) } // DialContextWithDialer implements C.ProxyAdapter @@ -300,7 +300,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o }() c, err = v.StreamConn(c, metadata) - return v.ListenPacketWithDialer(ctx, dialer.Dialer{Options: v.Base.DialOptions(opts...)}, metadata) + return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata) } // ListenPacketWithDialer implements C.ProxyAdapter diff --git a/adapter/outboundgroup/relay.go b/adapter/outboundgroup/relay.go index 8205ee52..43ef81c9 100644 --- a/adapter/outboundgroup/relay.go +++ b/adapter/outboundgroup/relay.go @@ -59,7 +59,7 @@ func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d return proxies[0].DialContext(ctx, metadata, r.Base.DialOptions(opts...)...) } var d C.Dialer - d = dialer.Dialer{Options: r.Base.DialOptions(opts...)} + d = dialer.NewDialer(r.Base.DialOptions(opts...)...) for _, proxy := range proxies[:len(proxies)-1] { d = proxyDialer{ proxy: proxy, @@ -93,7 +93,7 @@ func (r *Relay) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o } var d C.Dialer - d = dialer.Dialer{Options: r.Base.DialOptions(opts...)} + d = dialer.NewDialer(r.Base.DialOptions(opts...)...) for _, proxy := range proxies[:len(proxies)-1] { d = proxyDialer{ proxy: proxy, diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 256ff495..027b25b9 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -445,14 +445,19 @@ func concurrentIPv6DialContext(ctx context.Context, network, address string, opt return concurrentDialContext(ctx, network, ips, port, opt) } -type Dialer struct { - Options []Option +type dialer struct { + opt option } -func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return DialContext(ctx, network, address, d.Options...) +func (d dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return DialContext(ctx, network, address, withOption(d.opt)) } -func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { - return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, d.Options...) +func (d dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { + return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, withOption(d.opt)) +} + +func NewDialer(options ...Option) dialer { + opt := ApplyOptions(options...) + return dialer{opt: *opt} } diff --git a/component/dialer/options.go b/component/dialer/options.go index 98d0b8bd..8cd6fd39 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -68,3 +68,9 @@ func WithOnlySingleStack(isIPv4 bool) Option { } } } + +func withOption(o option) Option { + return func(opt *option) { + *opt = o + } +} diff --git a/transport/tuic/client.go b/transport/tuic/client.go index dcb4e3aa..418c4133 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -28,6 +28,7 @@ var ( ) type DialFunc func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) +type DialWithDialerFunc func(ctx context.Context, dialer C.Dialer) (pc net.PacketConn, addr net.Addr, err error) type ClientOption struct { TlsConfig *tls.Config diff --git a/transport/tuic/pool_client.go b/transport/tuic/pool_client.go index 9753da0d..304cc92d 100644 --- a/transport/tuic/pool_client.go +++ b/transport/tuic/pool_client.go @@ -37,9 +37,25 @@ func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, dial newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { return t.dial(ctx, dialFn, opts...) } - conn, err := t.getClient(false, opts...).DialContext(ctx, metadata, newDialFn, opts...) + var o any = *dialer.ApplyOptions(opts...) + conn, err := t.getClient(false, o).DialContext(ctx, metadata, newDialFn, opts...) if errors.Is(err, TooManyOpenStreams) { - conn, err = t.newClient(false, opts...).DialContext(ctx, metadata, newDialFn, opts...) + conn, err = t.newClient(false, o).DialContext(ctx, metadata, newDialFn, opts...) + } + if err != nil { + return nil, err + } + return N.NewRefConn(conn, t), err +} + +func (t *PoolClient) DialContextWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata, dialFn DialWithDialerFunc) (net.Conn, error) { + newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + return dialFn(ctx, d) + } + var o any = d + conn, err := t.getClient(false, o).DialContext(ctx, metadata, newDialFn) + if errors.Is(err, TooManyOpenStreams) { + conn, err = t.newClient(false, o).DialContext(ctx, metadata, newDialFn) } if err != nil { return nil, err @@ -51,9 +67,25 @@ func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metada newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { return t.dial(ctx, dialFn, opts...) } - pc, err := t.getClient(true, opts...).ListenPacketContext(ctx, metadata, newDialFn, opts...) + var o any = *dialer.ApplyOptions(opts...) + pc, err := t.getClient(true, o).ListenPacketContext(ctx, metadata, newDialFn, opts...) if errors.Is(err, TooManyOpenStreams) { - pc, err = t.newClient(false, opts...).ListenPacketContext(ctx, metadata, newDialFn, opts...) + pc, err = t.newClient(true, o).ListenPacketContext(ctx, metadata, newDialFn, opts...) + } + if err != nil { + return nil, err + } + return N.NewRefPacketConn(pc, t), nil +} + +func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata, dialFn DialWithDialerFunc) (net.PacketConn, error) { + newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) { + return dialFn(ctx, d) + } + var o any = d + pc, err := t.getClient(true, o).ListenPacketContext(ctx, metadata, newDialFn) + if errors.Is(err, TooManyOpenStreams) { + pc, err = t.newClient(true, o).ListenPacketContext(ctx, metadata, newDialFn) } if err != nil { return nil, err @@ -96,7 +128,7 @@ func (t *PoolClient) forceClose() { } } -func (t *PoolClient) newClient(udp bool, opts ...dialer.Option) *Client { +func (t *PoolClient) newClient(udp bool, o any) *Client { clients := t.tcpClients clientsMutex := t.tcpClientsMutex if udp { @@ -104,8 +136,6 @@ func (t *PoolClient) newClient(udp bool, opts ...dialer.Option) *Client { clientsMutex = t.udpClientsMutex } - var o any = *dialer.ApplyOptions(opts...) - clientsMutex.Lock() defer clientsMutex.Unlock() @@ -117,15 +147,13 @@ func (t *PoolClient) newClient(udp bool, opts ...dialer.Option) *Client { return client } -func (t *PoolClient) getClient(udp bool, opts ...dialer.Option) *Client { +func (t *PoolClient) getClient(udp bool, o any) *Client { clients := t.tcpClients clientsMutex := t.tcpClientsMutex if udp { clients = t.udpClients clientsMutex = t.udpClientsMutex } - - var o any = *dialer.ApplyOptions(opts...) var bestClient *Client func() { @@ -164,7 +192,7 @@ func (t *PoolClient) getClient(udp bool, opts ...dialer.Option) *Client { } if bestClient == nil { - return t.newClient(udp, opts...) + return t.newClient(udp, o) } else { bestClient.lastVisited = time.Now() return bestClient