diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 6e1f2426..d7e7072a 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -20,34 +20,16 @@ const ( DefaultUDPTimeout = DefaultTCPTimeout ) -type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) +type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) var ( dialMux sync.Mutex - IP4PEnable bool actualSingleStackDialContext = serialSingleStackDialContext actualDualStackDialContext = serialDualStackDialContext tcpConcurrent = false fallbackTimeout = 300 * time.Millisecond ) -func applyOptions(options ...Option) *option { - opt := &option{ - interfaceName: DefaultInterface.Load(), - routingMark: int(DefaultRoutingMark.Load()), - } - - for _, o := range DefaultOptions { - o(opt) - } - - for _, o := range options { - o(opt) - } - - return opt -} - func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { opt := applyOptions(options...) @@ -77,38 +59,43 @@ func DialContext(ctx context.Context, network, address string, options ...Option } func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) { - cfg := applyOptions(options...) + opt := applyOptions(options...) lc := &net.ListenConfig{} - if cfg.addrReuse { + if opt.addrReuse { addrReuseToListenConfig(lc) } if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CMFA) socketHookToListenConfig(lc) } else { - interfaceName := cfg.interfaceName - if interfaceName == "" { + if opt.interfaceName == "" { + opt.interfaceName = DefaultInterface.Load() + } + if opt.interfaceName == "" { if finder := DefaultInterfaceFinder.Load(); finder != nil { - interfaceName = finder.FindInterfaceName(rAddrPort.Addr()) + opt.interfaceName = finder.FindInterfaceName(rAddrPort.Addr()) } } if rAddrPort.Addr().Unmap().IsLoopback() { // avoid "The requested address is not valid in its context." - interfaceName = "" + opt.interfaceName = "" } - if interfaceName != "" { + if opt.interfaceName != "" { bind := bindIfaceToListenConfig - if cfg.fallbackBind { + if opt.fallbackBind { bind = fallbackBindIfaceToListenConfig } - addr, err := bind(interfaceName, lc, network, address, rAddrPort) + addr, err := bind(opt.interfaceName, lc, network, address, rAddrPort) if err != nil { return nil, err } address = addr } - if cfg.routingMark != 0 { - bindMarkToListenConfig(cfg.routingMark, lc, network, address) + if opt.routingMark == 0 { + opt.routingMark = int(DefaultRoutingMark.Load()) + } + if opt.routingMark != 0 { + bindMarkToListenConfig(opt.routingMark, lc, network, address) } } @@ -134,7 +121,7 @@ func GetTcpConcurrent() bool { return tcpConcurrent } -func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) { +func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) { var address string destination, port = resolver.LookupIP4P(destination, port) address = net.JoinHostPort(destination.String(), port) @@ -159,21 +146,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA) socketHookToToDialer(dialer) } else { - interfaceName := opt.interfaceName // don't change the "opt", it's a pointer - if interfaceName == "" { + if opt.interfaceName == "" { + opt.interfaceName = DefaultInterface.Load() + } + if opt.interfaceName == "" { if finder := DefaultInterfaceFinder.Load(); finder != nil { - interfaceName = finder.FindInterfaceName(destination) + opt.interfaceName = finder.FindInterfaceName(destination) } } - if interfaceName != "" { + if opt.interfaceName != "" { bind := bindIfaceToDialer if opt.fallbackBind { bind = fallbackBindIfaceToDialer } - if err := bind(interfaceName, dialer, network, destination); err != nil { + if err := bind(opt.interfaceName, dialer, network, destination); err != nil { return nil, err } } + if opt.routingMark == 0 { + opt.routingMark = int(DefaultRoutingMark.Load()) + } if opt.routingMark != 0 { bindMarkToDialer(opt.routingMark, dialer, network, destination) } @@ -185,26 +177,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po return dialer.DialContext(ctx, network, address) } -func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) } -func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt) } -func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { return parallelDialContext(ctx, network, ips, port, opt) } -func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { if opt.prefer != 4 && opt.prefer != 6 { return parallelDialContext(ctx, network, ips, port, opt) } return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt) } -func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { ipv4s, ipv6s := resolver.SortationAddr(ips) if len(ipv4s) == 0 && len(ipv6s) == 0 { return nil, ErrorNoIpAddress @@ -285,7 +277,7 @@ loop: return nil, errors.Join(errs...) } -func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { if len(ips) == 0 { return nil, ErrorNoIpAddress } @@ -324,7 +316,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, return nil, os.ErrDeadlineExceeded } -func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { +func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { if len(ips) == 0 { return nil, ErrorNoIpAddress } @@ -390,5 +382,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr func NewDialer(options ...Option) Dialer { opt := applyOptions(options...) - return Dialer{Opt: *opt} + return Dialer{Opt: opt} } diff --git a/component/dialer/options.go b/component/dialer/options.go index d15d36e8..bb978cdb 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -10,7 +10,6 @@ import ( ) var ( - DefaultOptions []Option DefaultInterface = atomic.NewTypedValue[string]("") DefaultRoutingMark = atomic.NewInt32(0) @@ -117,9 +116,13 @@ func WithOption(o option) Option { } func IsZeroOptions(opts []Option) bool { - var opt option - for _, o := range opts { + return applyOptions(opts...) == option{} +} + +func applyOptions(options ...Option) option { + opt := option{} + for _, o := range options { o(&opt) } - return opt == option{} + return opt }