diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index 9af1751b..19875b8c 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -12,7 +12,9 @@ import ( "strconv" "strings" "sync" + "time" + "github.com/metacubex/mihomo/common/atomic" CN "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/proxydialer" @@ -23,6 +25,8 @@ import ( wireguard "github.com/metacubex/sing-wireguard" + "github.com/jpillora/backoff" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" @@ -125,6 +129,48 @@ func (option WireGuardPeerOption) Prefixes() ([]netip.Prefix, error) { return localPrefixes, nil } +type wgSingDialer struct { + proxydialer.SingDialer + errTimes atomic.Int64 + backoff *backoff.Backoff +} + +func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if d.errTimes.Load() > 10 { + select { + case <-time.After(d.backoff.Duration()): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + c, err := d.SingDialer.DialContext(ctx, network, destination) + if err != nil { + d.errTimes.Add(1) + return nil, err + } + d.errTimes.Store(0) + d.backoff.Reset() + return c, nil +} + +func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if d.errTimes.Load() > 10 { + select { + case <-time.After(d.backoff.Duration()): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + c, err := d.SingDialer.ListenPacket(ctx, destination) + if err != nil { + d.errTimes.Add(1) + return nil, err + } + d.errTimes.Store(0) + d.backoff.Reset() + return c, nil +} + func NewWireGuard(option WireGuardOption) (*WireGuard, error) { outbound := &WireGuard{ Base: &Base{ @@ -136,7 +182,16 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { rmark: option.RoutingMark, prefer: C.NewDNSPrefer(option.IPVersion), }, - dialer: proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), + dialer: &wgSingDialer{ + SingDialer: proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), + errTimes: atomic.NewInt64(0), + backoff: &backoff.Backoff{ + Min: 10 * time.Millisecond, + Max: 1 * time.Second, + Factor: 2, + Jitter: true, + }, + }, } runtime.SetFinalizer(outbound, closeWireGuard)