chore: apply the default interface/mark of the dialer in the final stage

This commit is contained in:
wwqgtxx 2025-04-18 20:16:51 +08:00
parent 9c5067e519
commit 619c9dc0c6
2 changed files with 43 additions and 48 deletions

View file

@ -20,34 +20,16 @@ const (
DefaultUDPTimeout = DefaultTCPTimeout
)
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error)
var (
dialMux sync.Mutex
IP4PEnable bool
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
tcpConcurrent = false
fallbackTimeout = 300 * time.Millisecond
)
func applyOptions(options ...Option) *option {
opt := &option{
interfaceName: DefaultInterface.Load(),
routingMark: int(DefaultRoutingMark.Load()),
}
for _, o := range DefaultOptions {
o(opt)
}
for _, o := range options {
o(opt)
}
return opt
}
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
opt := applyOptions(options...)
@ -77,38 +59,43 @@ func DialContext(ctx context.Context, network, address string, options ...Option
}
func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) {
cfg := applyOptions(options...)
opt := applyOptions(options...)
lc := &net.ListenConfig{}
if cfg.addrReuse {
if opt.addrReuse {
addrReuseToListenConfig(lc)
}
if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CMFA)
socketHookToListenConfig(lc)
} else {
interfaceName := cfg.interfaceName
if interfaceName == "" {
if opt.interfaceName == "" {
opt.interfaceName = DefaultInterface.Load()
}
if opt.interfaceName == "" {
if finder := DefaultInterfaceFinder.Load(); finder != nil {
interfaceName = finder.FindInterfaceName(rAddrPort.Addr())
opt.interfaceName = finder.FindInterfaceName(rAddrPort.Addr())
}
}
if rAddrPort.Addr().Unmap().IsLoopback() {
// avoid "The requested address is not valid in its context."
interfaceName = ""
opt.interfaceName = ""
}
if interfaceName != "" {
if opt.interfaceName != "" {
bind := bindIfaceToListenConfig
if cfg.fallbackBind {
if opt.fallbackBind {
bind = fallbackBindIfaceToListenConfig
}
addr, err := bind(interfaceName, lc, network, address, rAddrPort)
addr, err := bind(opt.interfaceName, lc, network, address, rAddrPort)
if err != nil {
return nil, err
}
address = addr
}
if cfg.routingMark != 0 {
bindMarkToListenConfig(cfg.routingMark, lc, network, address)
if opt.routingMark == 0 {
opt.routingMark = int(DefaultRoutingMark.Load())
}
if opt.routingMark != 0 {
bindMarkToListenConfig(opt.routingMark, lc, network, address)
}
}
@ -134,7 +121,7 @@ func GetTcpConcurrent() bool {
return tcpConcurrent
}
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) {
var address string
destination, port = resolver.LookupIP4P(destination, port)
address = net.JoinHostPort(destination.String(), port)
@ -159,21 +146,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA)
socketHookToToDialer(dialer)
} else {
interfaceName := opt.interfaceName // don't change the "opt", it's a pointer
if interfaceName == "" {
if opt.interfaceName == "" {
opt.interfaceName = DefaultInterface.Load()
}
if opt.interfaceName == "" {
if finder := DefaultInterfaceFinder.Load(); finder != nil {
interfaceName = finder.FindInterfaceName(destination)
opt.interfaceName = finder.FindInterfaceName(destination)
}
}
if interfaceName != "" {
if opt.interfaceName != "" {
bind := bindIfaceToDialer
if opt.fallbackBind {
bind = fallbackBindIfaceToDialer
}
if err := bind(interfaceName, dialer, network, destination); err != nil {
if err := bind(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
}
if opt.routingMark == 0 {
opt.routingMark = int(DefaultRoutingMark.Load())
}
if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination)
}
@ -185,26 +177,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
return dialer.DialContext(ctx, network, address)
}
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return serialDialContext(ctx, network, ips, port, opt)
}
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
}
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return parallelDialContext(ctx, network, ips, port, opt)
}
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
}
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
ipv4s, ipv6s := resolver.SortationAddr(ips)
if len(ipv4s) == 0 && len(ipv6s) == 0 {
return nil, ErrorNoIpAddress
@ -285,7 +277,7 @@ loop:
return nil, errors.Join(errs...)
}
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
if len(ips) == 0 {
return nil, ErrorNoIpAddress
}
@ -324,7 +316,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr,
return nil, os.ErrDeadlineExceeded
}
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
if len(ips) == 0 {
return nil, ErrorNoIpAddress
}
@ -390,5 +382,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr
func NewDialer(options ...Option) Dialer {
opt := applyOptions(options...)
return Dialer{Opt: *opt}
return Dialer{Opt: opt}
}

View file

@ -10,7 +10,6 @@ import (
)
var (
DefaultOptions []Option
DefaultInterface = atomic.NewTypedValue[string]("")
DefaultRoutingMark = atomic.NewInt32(0)
@ -117,9 +116,13 @@ func WithOption(o option) Option {
}
func IsZeroOptions(opts []Option) bool {
var opt option
for _, o := range opts {
return applyOptions(opts...) == option{}
}
func applyOptions(options ...Option) option {
opt := option{}
for _, o := range options {
o(&opt)
}
return opt == option{}
return opt
}