From 17c081a40ca283fe0826852309eb4df08f3ab2b4 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 27 Nov 2022 11:09:56 +0800 Subject: [PATCH] add support for hysteria udp port hopping (#269) * add support for hysteria udp port hopping * add ports field for hysteria * change method for udp connection Co-authored-by: geoleonsh --- adapter/outbound/hysteria.go | 20 +- transport/hysteria/conns/udp/hop.go | 350 +++++++++++++++++++++++++ transport/hysteria/core/client.go | 34 +-- transport/hysteria/transport/client.go | 36 ++- transport/hysteria/utils/misc.go | 7 + 5 files changed, 416 insertions(+), 31 deletions(-) create mode 100644 transport/hysteria/conns/udp/hop.go diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index aced8f8f..fd441dc1 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -37,8 +37,9 @@ const ( DefaultConnectionReceiveWindow = 67108864 // 64 MB/s DefaultMaxIncomingStreams = 1024 - DefaultALPN = "hysteria" - DefaultProtocol = "udp" + DefaultALPN = "hysteria" + DefaultProtocol = "udp" + DefaultHopInterval = 10 ) var rateStringRegexp = regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`) @@ -90,6 +91,7 @@ type HysteriaOption struct { Name string `proxy:"name"` Server string `proxy:"server"` Port int `proxy:"port"` + Ports string `proxy:"ports"` Protocol string `proxy:"protocol,omitempty"` ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash Up string `proxy:"up"` @@ -110,6 +112,7 @@ type HysteriaOption struct { ReceiveWindow int `proxy:"recv-window,omitempty"` DisableMTUDiscovery bool `proxy:"disable-mtu-discovery,omitempty"` FastOpen bool `proxy:"fast-open,omitempty"` + HopInterval int `proxy:"hop-interval"` } func (c *HysteriaOption) Speed() (uint64, uint64, error) { @@ -133,8 +136,13 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) { Timeout: 8 * time.Second, }, } + var addr string + if option.Ports == "" { + addr = net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) + } else { + addr = net.JoinHostPort(option.Server, option.Ports) + } - addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) serverName := option.Server if option.SNI != "" { serverName = option.SNI @@ -199,6 +207,10 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) { if option.Protocol == "" { option.Protocol = DefaultProtocol } + if option.HopInterval == 0 { + option.HopInterval = DefaultHopInterval + } + hopInterval := time.Duration(int64(option.HopInterval)) * time.Second if option.ReceiveWindow == 0 { quicConfig.InitialStreamReceiveWindow = DefaultStreamReceiveWindow / 10 quicConfig.MaxStreamReceiveWindow = DefaultStreamReceiveWindow @@ -236,7 +248,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) { client, err := core.NewClient( addr, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, obfuscator, option.FastOpen, + }, obfuscator, hopInterval, option.FastOpen, ) if err != nil { return nil, fmt.Errorf("hysteria %s create error: %w", addr, err) diff --git a/transport/hysteria/conns/udp/hop.go b/transport/hysteria/conns/udp/hop.go new file mode 100644 index 00000000..e4958821 --- /dev/null +++ b/transport/hysteria/conns/udp/hop.go @@ -0,0 +1,350 @@ +package udp + +import ( + "errors" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "github.com/Dreamacro/clash/transport/hysteria/obfs" + "github.com/Dreamacro/clash/transport/hysteria/utils" +) + +const ( + packetQueueSize = 1024 +) + +// ObfsUDPHopClientPacketConn is the UDP port-hopping packet connection for client side. +// It hops to a different local & server port every once in a while. +type ObfsUDPHopClientPacketConn struct { + serverAddr net.Addr // Combined udpHopAddr + serverAddrs []net.Addr + hopInterval time.Duration + + obfs obfs.Obfuscator + + connMutex sync.RWMutex + prevConn net.PacketConn + currentConn net.PacketConn + addrIndex int + + readBufferSize int + writeBufferSize int + + recvQueue chan *udpPacket + closeChan chan struct{} + closed bool + + bufPool sync.Pool +} + +type udpHopAddr string + +func (a *udpHopAddr) Network() string { + return "udp-hop" +} + +func (a *udpHopAddr) String() string { + return string(*a) +} + +type udpPacket struct { + buf []byte + n int + addr net.Addr +} + +func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (*ObfsUDPHopClientPacketConn, error) { + host, ports, err := parseAddr(server) + if err != nil { + return nil, err + } + // Resolve the server IP address, then attach the ports to UDP addresses + ip, err := dialer.RemoteAddr(host) + if err != nil { + return nil, err + } + serverAddrs := make([]net.Addr, len(ports)) + for i, port := range ports { + serverAddrs[i] = &net.UDPAddr{ + IP: net.ParseIP(ip.String()), + Port: int(port), + } + } + hopAddr := udpHopAddr(server) + conn := &ObfsUDPHopClientPacketConn{ + serverAddr: &hopAddr, + serverAddrs: serverAddrs, + hopInterval: hopInterval, + obfs: obfs, + addrIndex: rand.Intn(len(serverAddrs)), + recvQueue: make(chan *udpPacket, packetQueueSize), + closeChan: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + return make([]byte, udpBufferSize) + }, + }, + } + curConn, err := dialer.ListenPacket() + if err != nil { + return nil, err + } + if obfs != nil { + conn.currentConn = NewObfsUDPConn(curConn, obfs) + } else { + conn.currentConn = curConn + } + go conn.recvRoutine(conn.currentConn) + go conn.hopRoutine(dialer) + return conn, nil +} + +func (c *ObfsUDPHopClientPacketConn) recvRoutine(conn net.PacketConn) { + for { + buf := c.bufPool.Get().([]byte) + n, addr, err := conn.ReadFrom(buf) + if err != nil { + return + } + select { + case c.recvQueue <- &udpPacket{buf, n, addr}: + default: + // Drop the packet if the queue is full + c.bufPool.Put(buf) + } + } +} + +func (c *ObfsUDPHopClientPacketConn) hopRoutine(dialer utils.PacketDialer) { + ticker := time.NewTicker(c.hopInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.hop(dialer) + case <-c.closeChan: + return + } + } +} + +func (c *ObfsUDPHopClientPacketConn) hop(dialer utils.PacketDialer) { + c.connMutex.Lock() + defer c.connMutex.Unlock() + if c.closed { + return + } + newConn, err := dialer.ListenPacket() + if err != nil { + // Skip this hop if failed to listen + return + } + // Close prevConn, + // prevConn <- currentConn + // currentConn <- newConn + // update addrIndex + // + // We need to keep receiving packets from the previous connection, + // because otherwise there will be packet loss due to the time gap + // between we hop to a new port and the server acknowledges this change. + if c.prevConn != nil { + _ = c.prevConn.Close() // recvRoutine will exit on error + } + c.prevConn = c.currentConn + if c.obfs != nil { + c.currentConn = NewObfsUDPConn(newConn, c.obfs) + } else { + c.currentConn = newConn + } + // Set buffer sizes if previously set + if c.readBufferSize > 0 { + _ = trySetPacketConnReadBuffer(c.currentConn, c.readBufferSize) + } + if c.writeBufferSize > 0 { + _ = trySetPacketConnWriteBuffer(c.currentConn, c.writeBufferSize) + } + go c.recvRoutine(c.currentConn) + c.addrIndex = rand.Intn(len(c.serverAddrs)) +} + +func (c *ObfsUDPHopClientPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + for { + select { + case p := <-c.recvQueue: + /* + // Check if the packet is from one of the server addresses + for _, addr := range c.serverAddrs { + if addr.String() == p.addr.String() { + // Copy the packet to the buffer + n := copy(b, p.buf[:p.n]) + c.bufPool.Put(p.buf) + return n, c.serverAddr, nil + } + } + // Drop the packet, continue + c.bufPool.Put(p.buf) + */ + // The above code was causing performance issues when the range is large, + // so we skip the check for now. Should probably still check by using a map + // or something in the future. + n := copy(b, p.buf[:p.n]) + c.bufPool.Put(p.buf) + return n, c.serverAddr, nil + case <-c.closeChan: + return 0, nil, net.ErrClosed + } + // Ignore packets from other addresses + } +} + +func (c *ObfsUDPHopClientPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + /* + // Check if the address is the server address + if addr.String() != c.serverAddr.String() { + return 0, net.ErrWriteToConnected + } + */ + // Skip the check for now, always write to the server + return c.currentConn.WriteTo(b, c.serverAddrs[c.addrIndex]) +} + +func (c *ObfsUDPHopClientPacketConn) Close() error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + if c.closed { + return nil + } + // Close prevConn and currentConn + // Close closeChan to unblock ReadFrom & hopRoutine + // Set closed flag to true to prevent double close + if c.prevConn != nil { + _ = c.prevConn.Close() + } + err := c.currentConn.Close() + close(c.closeChan) + c.closed = true + return err +} + +func (c *ObfsUDPHopClientPacketConn) LocalAddr() net.Addr { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + return c.currentConn.LocalAddr() +} + +func (c *ObfsUDPHopClientPacketConn) SetReadDeadline(t time.Time) error { + // Not supported + return nil +} + +func (c *ObfsUDPHopClientPacketConn) SetWriteDeadline(t time.Time) error { + // Not supported + return nil +} + +func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +func (c *ObfsUDPHopClientPacketConn) SetReadBuffer(bytes int) error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + c.readBufferSize = bytes + if c.prevConn != nil { + _ = trySetPacketConnReadBuffer(c.prevConn, bytes) + } + return trySetPacketConnReadBuffer(c.currentConn, bytes) +} + +func (c *ObfsUDPHopClientPacketConn) SetWriteBuffer(bytes int) error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + c.writeBufferSize = bytes + if c.prevConn != nil { + _ = trySetPacketConnWriteBuffer(c.prevConn, bytes) + } + return trySetPacketConnWriteBuffer(c.currentConn, bytes) +} + +func (c *ObfsUDPHopClientPacketConn) SyscallConn() (syscall.RawConn, error) { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + sc, ok := c.currentConn.(syscall.Conn) + if !ok { + return nil, errors.New("not supported") + } + return sc.SyscallConn() +} + +func trySetPacketConnReadBuffer(pc net.PacketConn, bytes int) error { + sc, ok := pc.(interface { + SetReadBuffer(bytes int) error + }) + if ok { + return sc.SetReadBuffer(bytes) + } + return nil +} + +func trySetPacketConnWriteBuffer(pc net.PacketConn, bytes int) error { + sc, ok := pc.(interface { + SetWriteBuffer(bytes int) error + }) + if ok { + return sc.SetWriteBuffer(bytes) + } + return nil +} + +// parseAddr parses the multi-port server address and returns the host and ports. +// Supports both comma-separated single ports and dash-separated port ranges. +// Format: "host:port1,port2-port3,port4" +func parseAddr(addr string) (host string, ports []uint16, err error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", nil, err + } + portStrs := strings.Split(portStr, ",") + for _, portStr := range portStrs { + if strings.Contains(portStr, "-") { + // Port range + portRange := strings.Split(portStr, "-") + if len(portRange) != 2 { + return "", nil, net.InvalidAddrError("invalid port range") + } + start, err := strconv.ParseUint(portRange[0], 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port range") + } + end, err := strconv.ParseUint(portRange[1], 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port range") + } + if start > end { + start, end = end, start + } + for i := start; i <= end; i++ { + ports = append(ports, uint16(i)) + } + } else { + // Single port + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port") + } + ports = append(ports, uint16(port)) + } + } + return host, ports, nil +} diff --git a/transport/hysteria/core/client.go b/transport/hysteria/core/client.go index 8a19570b..e98a0c6b 100644 --- a/transport/hysteria/core/client.go +++ b/transport/hysteria/core/client.go @@ -6,18 +6,20 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/Dreamacro/clash/transport/hysteria/obfs" - "github.com/Dreamacro/clash/transport/hysteria/pmtud_fix" - "github.com/Dreamacro/clash/transport/hysteria/transport" - "github.com/Dreamacro/clash/transport/hysteria/utils" - "github.com/lunixbochs/struc" - "github.com/metacubex/quic-go" - "github.com/metacubex/quic-go/congestion" "math/rand" "net" "strconv" "sync" "time" + + "github.com/lunixbochs/struc" + "github.com/metacubex/quic-go" + "github.com/metacubex/quic-go/congestion" + + "github.com/Dreamacro/clash/transport/hysteria/obfs" + "github.com/Dreamacro/clash/transport/hysteria/pmtud_fix" + "github.com/Dreamacro/clash/transport/hysteria/transport" + "github.com/Dreamacro/clash/transport/hysteria/utils" ) var ( @@ -45,12 +47,13 @@ type Client struct { udpSessionMutex sync.RWMutex udpSessionMap map[uint32]chan *udpMessage udpDefragger defragger + hopInterval time.Duration fastOpen bool } func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, - obfuscator obfs.Obfuscator,fastOpen bool) (*Client, error) { + obfuscator obfs.Obfuscator, hopInterval time.Duration, fastOpen bool) (*Client, error) { quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery c := &Client{ transport: transport, @@ -63,13 +66,14 @@ func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.C obfuscator: obfuscator, tlsConfig: tlsConfig, quicConfig: quicConfig, - fastOpen: fastOpen, + hopInterval: hopInterval, + fastOpen: fastOpen, } return c, nil } -func (c *Client) connectToServer(dialer transport.PacketDialer) error { - qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, dialer) +func (c *Client) connectToServer(dialer utils.PacketDialer) error { + qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, c.hopInterval, dialer) if err != nil { return err } @@ -156,7 +160,7 @@ func (c *Client) handleMessage(qs quic.Connection) { } } -func (c *Client) openStreamWithReconnect(dialer transport.PacketDialer) (quic.Connection, quic.Stream, error) { +func (c *Client) openStreamWithReconnect(dialer utils.PacketDialer) (quic.Connection, quic.Stream, error) { c.reconnectMutex.Lock() defer c.reconnectMutex.Unlock() if c.closed { @@ -188,7 +192,7 @@ func (c *Client) openStreamWithReconnect(dialer transport.PacketDialer) (quic.Co return c.quicSession, &wrappedQUICStream{stream}, err } -func (c *Client) DialTCP(addr string, dialer transport.PacketDialer) (net.Conn, error) { +func (c *Client) DialTCP(addr string, dialer utils.PacketDialer) (net.Conn, error) { host, port, err := utils.SplitHostPort(addr) if err != nil { return nil, err @@ -227,11 +231,11 @@ func (c *Client) DialTCP(addr string, dialer transport.PacketDialer) (net.Conn, Orig: stream, PseudoLocalAddr: session.LocalAddr(), PseudoRemoteAddr: session.RemoteAddr(), - Established: !c.fastOpen, + Established: !c.fastOpen, }, nil } -func (c *Client) DialUDP(dialer transport.PacketDialer) (UDPConn, error) { +func (c *Client) DialUDP(dialer utils.PacketDialer) (UDPConn, error) { session, stream, err := c.openStreamWithReconnect(dialer) if err != nil { return nil, err diff --git a/transport/hysteria/transport/client.go b/transport/hysteria/transport/client.go index a48e9bf5..c30377a3 100644 --- a/transport/hysteria/transport/client.go +++ b/transport/hysteria/transport/client.go @@ -1,31 +1,41 @@ package transport import ( - "context" "crypto/tls" "fmt" + "net" + "strings" + "time" + + "github.com/metacubex/quic-go" + "github.com/Dreamacro/clash/transport/hysteria/conns/faketcp" "github.com/Dreamacro/clash/transport/hysteria/conns/udp" "github.com/Dreamacro/clash/transport/hysteria/conns/wechat" obfsPkg "github.com/Dreamacro/clash/transport/hysteria/obfs" - "github.com/metacubex/quic-go" - "net" + "github.com/Dreamacro/clash/transport/hysteria/utils" ) type ClientTransport struct { Dialer *net.Dialer } -func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfsPkg.Obfuscator, dialer PacketDialer) (net.PacketConn, error) { +func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) { if len(proto) == 0 || proto == "udp" { conn, err := dialer.ListenPacket() if err != nil { return nil, err } if obfs != nil { + if isMultiPortAddr(server) { + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, obfs, dialer) + } oc := udp.NewObfsUDPConn(conn, obfs) return oc, nil } else { + if isMultiPortAddr(server) { + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil, dialer) + } return conn, nil } } else if proto == "wechat-video" { @@ -54,19 +64,13 @@ func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfs } } -type PacketDialer interface { - ListenPacket() (net.PacketConn, error) - Context() context.Context - RemoteAddr(host string) (net.Addr, error) -} - -func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfsPkg.Obfuscator, dialer PacketDialer) (quic.Connection, error) { +func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (quic.Connection, error) { serverUDPAddr, err := dialer.RemoteAddr(server) if err != nil { return nil, err } - pktConn, err := ct.quicPacketConn(proto, serverUDPAddr.String(), obfs, dialer) + pktConn, err := ct.quicPacketConn(proto, serverUDPAddr.String(), obfs, hopInterval, dialer) if err != nil { return nil, err } @@ -90,3 +94,11 @@ func (ct *ClientTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { func (ct *ClientTransport) ListenUDP() (*net.UDPConn, error) { return net.ListenUDP("udp", nil) } + +func isMultiPortAddr(addr string) bool { + _, portStr, err := net.SplitHostPort(addr) + if err == nil && (strings.Contains(portStr, ",") || strings.Contains(portStr, "-")) { + return true + } + return false +} diff --git a/transport/hysteria/utils/misc.go b/transport/hysteria/utils/misc.go index 29c7cf0c..5d5159fc 100644 --- a/transport/hysteria/utils/misc.go +++ b/transport/hysteria/utils/misc.go @@ -1,6 +1,7 @@ package utils import ( + "context" "net" "strconv" ) @@ -40,3 +41,9 @@ func last(s string, b byte) int { } return i } + +type PacketDialer interface { + ListenPacket() (net.PacketConn, error) + Context() context.Context + RemoteAddr(host string) (net.Addr, error) +}