From 2a0139e236922661506b28813bb32033042ddf3f Mon Sep 17 00:00:00 2001 From: keakon Date: Wed, 17 Jan 2024 20:00:01 +0800 Subject: [PATCH] support port hopping for hysteria 2 --- adapter/outbound/hysteria2.go | 172 ++-- go.mod | 8 +- go.sum | 17 +- transport/hysteria2/README.md | 1 + transport/hysteria2/app/cmd/client.go | 374 +++++++ transport/hysteria2/app/cmd/errors.go | 18 + transport/hysteria2/app/utils/bpsconv.go | 68 ++ transport/hysteria2/core/client/client.go | 316 ++++++ transport/hysteria2/core/client/config.go | 112 +++ transport/hysteria2/core/client/reconnect.go | 117 +++ transport/hysteria2/core/client/udp.go | 223 +++++ transport/hysteria2/core/errors/errors.go | 75 ++ .../core/internal/congestion/bbr/bandwidth.go | 27 + .../congestion/bbr/bandwidth_sampler.go | 874 ++++++++++++++++ .../internal/congestion/bbr/bbr_sender.go | 944 ++++++++++++++++++ .../core/internal/congestion/bbr/clock.go | 18 + .../bbr/packet_number_indexed_queue.go | 199 ++++ .../internal/congestion/bbr/ringbuffer.go | 118 +++ .../congestion/bbr/windowed_filter.go | 162 +++ .../core/internal/congestion/brutal/brutal.go | 181 ++++ .../core/internal/congestion/common/pacer.go | 95 ++ .../core/internal/congestion/utils.go | 18 + .../hysteria2/core/internal/frag/frag.go | 77 ++ .../hysteria2/core/internal/pmtud/avail.go | 7 + .../hysteria2/core/internal/pmtud/unavail.go | 13 + .../hysteria2/core/internal/protocol/http.go | 68 ++ .../core/internal/protocol/padding.go | 31 + .../hysteria2/core/internal/protocol/proxy.go | 255 +++++ .../hysteria2/core/internal/utils/atomic.go | 24 + .../hysteria2/core/internal/utils/qstream.go | 62 ++ .../hysteria2/extras/correctnet/correctnet.go | 92 ++ transport/hysteria2/extras/obfs/conn.go | 121 +++ transport/hysteria2/extras/obfs/salamander.go | 71 ++ .../hysteria2/extras/obfs/salamander_test.go | 45 + .../hysteria2/extras/transport/udphop/addr.go | 92 ++ .../hysteria2/extras/transport/udphop/conn.go | 288 ++++++ 36 files changed, 5282 insertions(+), 101 deletions(-) create mode 100644 transport/hysteria2/README.md create mode 100644 transport/hysteria2/app/cmd/client.go create mode 100644 transport/hysteria2/app/cmd/errors.go create mode 100644 transport/hysteria2/app/utils/bpsconv.go create mode 100644 transport/hysteria2/core/client/client.go create mode 100644 transport/hysteria2/core/client/config.go create mode 100644 transport/hysteria2/core/client/reconnect.go create mode 100644 transport/hysteria2/core/client/udp.go create mode 100644 transport/hysteria2/core/errors/errors.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/bandwidth.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/bandwidth_sampler.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/bbr_sender.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/clock.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/packet_number_indexed_queue.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/ringbuffer.go create mode 100644 transport/hysteria2/core/internal/congestion/bbr/windowed_filter.go create mode 100644 transport/hysteria2/core/internal/congestion/brutal/brutal.go create mode 100644 transport/hysteria2/core/internal/congestion/common/pacer.go create mode 100644 transport/hysteria2/core/internal/congestion/utils.go create mode 100644 transport/hysteria2/core/internal/frag/frag.go create mode 100644 transport/hysteria2/core/internal/pmtud/avail.go create mode 100644 
transport/hysteria2/core/internal/pmtud/unavail.go create mode 100644 transport/hysteria2/core/internal/protocol/http.go create mode 100644 transport/hysteria2/core/internal/protocol/padding.go create mode 100644 transport/hysteria2/core/internal/protocol/proxy.go create mode 100644 transport/hysteria2/core/internal/utils/atomic.go create mode 100644 transport/hysteria2/core/internal/utils/qstream.go create mode 100644 transport/hysteria2/extras/correctnet/correctnet.go create mode 100644 transport/hysteria2/extras/obfs/conn.go create mode 100644 transport/hysteria2/extras/obfs/salamander.go create mode 100644 transport/hysteria2/extras/obfs/salamander_test.go create mode 100644 transport/hysteria2/extras/transport/udphop/addr.go create mode 100644 transport/hysteria2/extras/transport/udphop/conn.go diff --git a/adapter/outbound/hysteria2.go b/adapter/outbound/hysteria2.go index ddd5ccea..5ca32306 100644 --- a/adapter/outbound/hysteria2.go +++ b/adapter/outbound/hysteria2.go @@ -2,137 +2,123 @@ package outbound import ( "context" - "crypto/tls" - "errors" - "fmt" "net" - "runtime" "strconv" + "time" - CN "github.com/metacubex/mihomo/common/net" - "github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/dialer" - "github.com/metacubex/mihomo/component/proxydialer" C "github.com/metacubex/mihomo/constant" - tuicCommon "github.com/metacubex/mihomo/transport/tuic/common" - - "github.com/metacubex/sing-quic/hysteria2" - - M "github.com/sagernet/sing/common/metadata" + "github.com/metacubex/mihomo/transport/hysteria2/app/cmd" + hy2client "github.com/metacubex/mihomo/transport/hysteria2/core/client" ) -func init() { - hysteria2.SetCongestionController = tuicCommon.SetCongestionController -} +const minHopInterval = 5 +const defaultHopInterval = 30 type Hysteria2 struct { *Base option *Hysteria2Option - client *hysteria2.Client - dialer proxydialer.SingDialer + client hy2client.Client } type Hysteria2Option struct { BasicOption - Name string `proxy:"name"` - Server string `proxy:"server"` - Port int `proxy:"port"` - Up string `proxy:"up,omitempty"` - Down string `proxy:"down,omitempty"` - Password string `proxy:"password,omitempty"` - Obfs string `proxy:"obfs,omitempty"` - ObfsPassword string `proxy:"obfs-password,omitempty"` - SNI string `proxy:"sni,omitempty"` - SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` - Fingerprint string `proxy:"fingerprint,omitempty"` - ALPN []string `proxy:"alpn,omitempty"` - CustomCA string `proxy:"ca,omitempty"` - CustomCAString string `proxy:"ca-str,omitempty"` - CWND int `proxy:"cwnd,omitempty"` + Name string `proxy:"name"` + Server string `proxy:"server"` + Port uint16 `proxy:"port,omitempty"` + Ports string `proxy:"ports,omitempty"` + HopInterval time.Duration `proxy:"hop-interval,omitempty"` + Up string `proxy:"up"` + Down string `proxy:"down"` + Password string `proxy:"password,omitempty"` + Obfs string `proxy:"obfs,omitempty"` + ObfsPassword string `proxy:"obfs-password,omitempty"` + SNI string `proxy:"sni,omitempty"` + SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` + Fingerprint string `proxy:"fingerprint,omitempty"` + ALPN []string `proxy:"alpn,omitempty"` + CustomCA string `proxy:"ca,omitempty"` + CustomCAString string `proxy:"ca-str,omitempty"` + CWND int `proxy:"cwnd,omitempty"` + FastOpen bool `proxy:"fast-open,omitempty"` + Lazy bool `proxy:"lazy,omitempty"` } func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - options := 
h.Base.DialOptions(opts...) - h.dialer.SetDialer(dialer.NewDialer(options...)) - c, err := h.client.DialConn(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) + tcpConn, err := h.client.TCP(net.JoinHostPort(metadata.String(), strconv.Itoa(int(metadata.DstPort)))) if err != nil { return nil, err } - return NewConn(CN.NewRefConn(c, h), h), nil + + return NewConn(tcpConn, h), nil } func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { - options := h.Base.DialOptions(opts...) - h.dialer.SetDialer(dialer.NewDialer(options...)) - pc, err := h.client.ListenPacket(ctx) + udpConn, err := h.client.UDP() if err != nil { return nil, err } - if pc == nil { - return nil, errors.New("packetConn is nil") - } - return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), h), h), nil -} - -func closeHysteria2(h *Hysteria2) { - if h.client != nil { - _ = h.client.CloseWithError(errors.New("proxy removed")) - } + return newPacketConn(udpConn, h), nil } func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) { - addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) - var salamanderPassword string - if len(option.Obfs) > 0 { - if option.ObfsPassword == "" { - return nil, errors.New("missing obfs password") - } - switch option.Obfs { - case hysteria2.ObfsTypeSalamander: - salamanderPassword = option.ObfsPassword - default: - return nil, fmt.Errorf("unknown obfs type: %s", option.Obfs) - } + var server string + if option.Ports != "" { + server = net.JoinHostPort(option.Server, option.Ports) + } else { + server = net.JoinHostPort(option.Server, strconv.Itoa(int(option.Port))) } - serverName := option.Server - if option.SNI != "" { - serverName = option.SNI + if option.HopInterval == 0 { + option.HopInterval = defaultHopInterval + } else if option.HopInterval < minHopInterval { + option.HopInterval = minHopInterval + } + option.HopInterval *= time.Second + + config := cmd.ClientConfig{ + Server: server, + Auth: option.Password, + Transport: cmd.ClientConfigTransport{ + UDP: cmd.ClientConfigTransportUDP{ + HopInterval: option.HopInterval, + }, + }, + TLS: cmd.ClientConfigTLS{ + SNI: option.SNI, + Insecure: option.SkipCertVerify, + PinSHA256: option.Fingerprint, + CA: option.CustomCA, + CAString: option.CustomCAString, + }, + FastOpen: option.FastOpen, + Lazy: option.Lazy, } - tlsConfig := &tls.Config{ - ServerName: serverName, - InsecureSkipVerify: option.SkipCertVerify, - MinVersion: tls.VersionTLS13, + if option.ObfsPassword != "" { + config.Obfs.Type = "salamander" + config.Obfs.Salamander.Password = option.ObfsPassword + } else if option.Obfs != "" { + config.Obfs.Type = "salamander" + config.Obfs.Salamander.Password = option.Obfs } - var err error - tlsConfig, err = ca.GetTLSConfig(tlsConfig, option.Fingerprint, option.CustomCA, option.CustomCAString) - if err != nil { - return nil, err + last := option.Up[len(option.Up)-1] + if '0' <= last && last <= '9' { + option.Up += "m" } - - if len(option.ALPN) > 0 { - tlsConfig.NextProtos = option.ALPN + config.Bandwidth.Up = option.Up + last = option.Down[len(option.Down)-1] + if '0' <= last && last <= '9' { + option.Down += "m" } + config.Bandwidth.Down = option.Down - singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()) - - clientOptions := hysteria2.ClientOptions{ - Context: context.TODO(), - Dialer: singDialer, - ServerAddress: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)), - SendBPS: 
StringToBps(option.Up), - ReceiveBPS: StringToBps(option.Down), - SalamanderPassword: salamanderPassword, - Password: option.Password, - TLSConfig: tlsConfig, - UDPDisabled: false, - CWND: option.CWND, - } - - client, err := hysteria2.NewClient(clientOptions) + client, err := hy2client.NewReconnectableClient( + config.Config, + func(c hy2client.Client, info *hy2client.HandshakeInfo, count int) {}, + option.Lazy) if err != nil { return nil, err } @@ -140,7 +126,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) { outbound := &Hysteria2{ Base: &Base{ name: option.Name, - addr: addr, + addr: server, tp: C.Hysteria2, udp: true, iface: option.Interface, @@ -149,9 +135,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) { }, option: &option, client: client, - dialer: singDialer, } - runtime.SetFinalizer(outbound, closeHysteria2) return outbound, nil } diff --git a/go.mod b/go.mod index f8252128..ae0b7d41 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/metacubex/mihomo -go 1.20 +go 1.21 require ( github.com/3andne/restls-client-go v0.1.6 @@ -65,11 +65,12 @@ require ( github.com/andybalholm/brotli v1.0.6 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cloudflare/circl v1.3.6 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/ericlagergren/aegis v0.0.0-20230312195928-b4ce538b56f9 // indirect github.com/ericlagergren/polyval v0.0.0-20220411101811-e25bc10ba391 // indirect github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 // indirect github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010 // indirect + github.com/frankban/quicktest v1.14.6 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-ole/go-ole v1.3.0 // indirect @@ -89,7 +90,7 @@ require ( github.com/oasisprotocol/deoxysii v0.0.0-20220228165953-2091330c22b7 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pierrec/lz4/v4 v4.1.14 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect @@ -110,6 +111,7 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.16.0 // indirect + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) replace github.com/sagernet/sing => github.com/metacubex/sing v0.0.0-20240111014253-f1818b6a82b2 diff --git a/go.sum b/go.sum index c3209df8..91608fb2 100644 --- a/go.sum +++ b/go.sum @@ -25,9 +25,11 @@ github.com/cloudflare/circl v1.3.6 h1:/xbKIqSHbZXHwkhbrhrt2YOHIwYJlXH94E3tI/gDlU github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew 
v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/ericlagergren/aegis v0.0.0-20230312195928-b4ce538b56f9 h1:/5RkVc9Rc81XmMyVqawCiDyrBHZbLAZgTTCqou4mwj8= @@ -39,7 +41,8 @@ github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 h1:tlDMEdcPRQKBE github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1/go.mod h1:4RfsapbGx2j/vU5xC/5/9qB3kn9Awp1YDiEnN43QrJ4= github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010 h1:fuGucgPk5dN6wzfnxl3D0D3rVLw4v2SbBT9jb4VnxzA= github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010/go.mod h1:JtBcj7sBuTTRupn7c2bFspMDIObMJsVK8TeUvpShPok= -github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= @@ -91,7 +94,9 @@ github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6K github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc= @@ -140,9 +145,11 @@ github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq5 github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE= github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat 
v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= @@ -153,6 +160,7 @@ github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1 github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkkD2QgdTuzQG263YZ+2emfpeyGqW0= github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= @@ -270,7 +278,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/transport/hysteria2/README.md b/transport/hysteria2/README.md new file mode 100644 index 00000000..315ace2c --- /dev/null +++ b/transport/hysteria2/README.md @@ -0,0 +1 @@ +Copied from [hysteria](https://github.com/apernet/hysteria) v2.2.3 with minor changes.
\ No newline at end of file diff --git a/transport/hysteria2/app/cmd/client.go b/transport/hysteria2/app/cmd/client.go new file mode 100644 index 00000000..27c4ef4d --- /dev/null +++ b/transport/hysteria2/app/cmd/client.go @@ -0,0 +1,374 @@ +package cmd + +import ( + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "errors" + "net" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/metacubex/mihomo/transport/hysteria2/app/utils" + "github.com/metacubex/mihomo/transport/hysteria2/core/client" + "github.com/metacubex/mihomo/transport/hysteria2/extras/obfs" + "github.com/metacubex/mihomo/transport/hysteria2/extras/transport/udphop" +) + +type ClientConfig struct { + Server string `mapstructure:"server"` + Auth string `mapstructure:"auth"` + Transport ClientConfigTransport `mapstructure:"transport"` + Obfs ClientConfigObfs `mapstructure:"obfs"` + TLS ClientConfigTLS `mapstructure:"tls"` + QUIC clientConfigQUIC `mapstructure:"quic"` + Bandwidth ClientConfigBandwidth `mapstructure:"bandwidth"` + FastOpen bool `mapstructure:"fastOpen"` + Lazy bool `mapstructure:"lazy"` +} + +type ClientConfigTransportUDP struct { + HopInterval time.Duration `mapstructure:"hopInterval"` +} + +type ClientConfigTransport struct { + Type string `mapstructure:"type"` + UDP ClientConfigTransportUDP `mapstructure:"udp"` +} + +type ClientConfigObfsSalamander struct { + Password string `mapstructure:"password"` +} + +type ClientConfigObfs struct { + Type string `mapstructure:"type"` + Salamander ClientConfigObfsSalamander `mapstructure:"salamander"` +} + +type ClientConfigTLS struct { + SNI string `mapstructure:"sni"` + Insecure bool `mapstructure:"insecure"` + PinSHA256 string `mapstructure:"pinSHA256"` + CA string `mapstructure:"ca"` + CAString string `mapstructure:"ca-str"` +} + +type clientConfigQUIC struct { + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` +} + +type ClientConfigBandwidth struct { + Up string `mapstructure:"up"` + Down string `mapstructure:"down"` +} + +func (c *ClientConfig) fillServerAddr(hyConfig *client.Config) error { + if c.Server == "" { + return configError{Field: "server", Err: errors.New("server address is empty")} + } + var addr net.Addr + var err error + host, port, hostPort := parseServerAddrString(c.Server) + if !isPortHoppingPort(port) { + addr, err = net.ResolveUDPAddr("udp", hostPort) + } else { + addr, err = udphop.ResolveUDPHopAddr(hostPort) + } + if err != nil { + return configError{Field: "server", Err: err} + } + hyConfig.ServerAddr = addr + // Special handling for SNI + if c.TLS.SNI == "" { + // Use server hostname as SNI + hyConfig.TLSConfig.ServerName = host + } + return nil +} + +// fillConnFactory must be called after fillServerAddr, as we have different logic +// for ConnFactory depending on whether we have a port hopping address. 
+func (c *ClientConfig) fillConnFactory(hyConfig *client.Config) error { + // Inner PacketConn + var newFunc func(addr net.Addr) (net.PacketConn, error) + switch strings.ToLower(c.Transport.Type) { + case "", "udp": + if hyConfig.ServerAddr.Network() == "udphop" { + hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr) + newFunc = func(addr net.Addr) (net.PacketConn, error) { + return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval) + } + } else { + newFunc = func(addr net.Addr) (net.PacketConn, error) { + return net.ListenUDP("udp", nil) + } + } + default: + return configError{Field: "transport.type", Err: errors.New("unsupported transport type")} + } + // Obfuscation + var ob obfs.Obfuscator + var err error + switch strings.ToLower(c.Obfs.Type) { + case "", "plain": + // Keep it nil + case "salamander": + ob, err = obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password)) + if err != nil { + return configError{Field: "obfs.salamander.password", Err: err} + } + default: + return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")} + } + hyConfig.ConnFactory = &adaptiveConnFactory{ + NewFunc: newFunc, + Obfuscator: ob, + } + return nil +} + +func (c *ClientConfig) fillAuth(hyConfig *client.Config) error { + hyConfig.Auth = c.Auth + return nil +} + +func (c *ClientConfig) fillTLSConfig(hyConfig *client.Config) error { + if c.TLS.SNI != "" { + hyConfig.TLSConfig.ServerName = c.TLS.SNI + } + hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure + if c.TLS.PinSHA256 != "" { + nHash := normalizeCertHash(c.TLS.PinSHA256) + hyConfig.TLSConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + for _, cert := range rawCerts { + hash := sha256.Sum256(cert) + hashHex := hex.EncodeToString(hash[:]) + if hashHex == nHash { + return nil + } + } + // No match + return errors.New("no certificate matches the pinned hash") + } + } + if c.TLS.CAString != "" || c.TLS.CA != "" { + var ca []byte + if c.TLS.CAString != "" { + ca = []byte(c.TLS.CAString) + } else { + var err error + ca, err = os.ReadFile(c.TLS.CA) + if err != nil { + return configError{Field: "tls.ca", Err: err} + } + } + cPool := x509.NewCertPool() + if !cPool.AppendCertsFromPEM(ca) { + return configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")} + } + hyConfig.TLSConfig.RootCAs = cPool + } + return nil +} + +func (c *ClientConfig) fillQUICConfig(hyConfig *client.Config) error { + hyConfig.QUICConfig = client.QUICConfig{ + InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow, + MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow, + MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow, + MaxIdleTimeout: c.QUIC.MaxIdleTimeout, + KeepAlivePeriod: c.QUIC.KeepAlivePeriod, + DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery, + } + return nil +} + +func (c *ClientConfig) fillBandwidthConfig(hyConfig *client.Config) error { + // New core now allows users to omit bandwidth values and use built-in congestion control + var err error + if c.Bandwidth.Up != "" { + hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up) + if err != nil { + return configError{Field: "bandwidth.up", Err: err} + } + } + if c.Bandwidth.Down != "" { + hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down) + if err != nil { + return configError{Field: "bandwidth.down", Err: err} + } + } + return nil +} + +func (c *ClientConfig) 
fillFastOpen(hyConfig *client.Config) error { + hyConfig.FastOpen = c.FastOpen + return nil +} + +// URI generates a URI for sharing the config with others. +// Note that only the bare minimum of information required to +// connect to the server is included in the URI, specifically: +// - server address +// - authentication +// - obfuscation type +// - obfuscation password +// - TLS SNI +// - TLS insecure +// - TLS pinned SHA256 hash (normalized) +func (c *ClientConfig) URI() string { + q := url.Values{} + switch strings.ToLower(c.Obfs.Type) { + case "salamander": + q.Set("obfs", "salamander") + q.Set("obfs-password", c.Obfs.Salamander.Password) + } + if c.TLS.SNI != "" { + q.Set("sni", c.TLS.SNI) + } + if c.TLS.Insecure { + q.Set("insecure", "1") + } + if c.TLS.PinSHA256 != "" { + q.Set("pinSHA256", normalizeCertHash(c.TLS.PinSHA256)) + } + var user *url.Userinfo + if c.Auth != "" { + // We need to handle the special case of user:pass pairs + rs := strings.SplitN(c.Auth, ":", 2) + if len(rs) == 2 { + user = url.UserPassword(rs[0], rs[1]) + } else { + user = url.User(c.Auth) + } + } + u := url.URL{ + Scheme: "hysteria2", + User: user, + Host: c.Server, + Path: "/", + RawQuery: q.Encode(), + } + return u.String() +} + +// parseURI tries to parse the server address field as a URI, +// and fills the config with the information contained in the URI. +// Returns whether the server address field is a valid URI. +// This allows a user to use put a URI as the server address and +// omit the fields that are already contained in the URI. +func (c *ClientConfig) parseURI() bool { + u, err := url.Parse(c.Server) + if err != nil { + return false + } + if u.Scheme != "hysteria2" && u.Scheme != "hy2" { + return false + } + if u.User != nil { + c.Auth = u.User.String() + } + c.Server = u.Host + q := u.Query() + if obfsType := q.Get("obfs"); obfsType != "" { + c.Obfs.Type = obfsType + switch strings.ToLower(obfsType) { + case "salamander": + c.Obfs.Salamander.Password = q.Get("obfs-password") + } + } + if sni := q.Get("sni"); sni != "" { + c.TLS.SNI = sni + } + if insecure, err := strconv.ParseBool(q.Get("insecure")); err == nil { + c.TLS.Insecure = insecure + } + if pinSHA256 := q.Get("pinSHA256"); pinSHA256 != "" { + c.TLS.PinSHA256 = pinSHA256 + } + return true +} + +// Config validates the fields and returns a ready-to-use Hysteria client config +func (c *ClientConfig) Config() (*client.Config, error) { + c.parseURI() + hyConfig := &client.Config{} + fillers := []func(*client.Config) error{ + c.fillServerAddr, + c.fillConnFactory, + c.fillAuth, + c.fillTLSConfig, + c.fillQUICConfig, + c.fillBandwidthConfig, + c.fillFastOpen, + } + for _, f := range fillers { + if err := f(hyConfig); err != nil { + return nil, err + } + } + return hyConfig, nil +} + +type clientModeRunner struct { + ModeMap map[string]func() error +} + +func (r *clientModeRunner) Add(name string, f func() error) { + if r.ModeMap == nil { + r.ModeMap = make(map[string]func() error) + } + r.ModeMap[name] = f +} + +// parseServerAddrString parses server address string. +// Server address can be in either "host:port" or "host" format (in which case we assume port 443). +func parseServerAddrString(addrStr string) (host, port, hostPort string) { + h, p, err := net.SplitHostPort(addrStr) + if err != nil { + return addrStr, "443", net.JoinHostPort(addrStr, "443") + } + return h, p, addrStr +} + +// isPortHoppingPort returns whether the port string is a port hopping port. 
+// We consider a port string to be a port hopping port if it contains "-" or ",". +func isPortHoppingPort(port string) bool { + return strings.Contains(port, "-") || strings.Contains(port, ",") +} + +// normalizeCertHash normalizes a certificate hash string. +// It converts all characters to lowercase and removes possible separators such as ":" and "-". +func normalizeCertHash(hash string) string { + r := strings.ToLower(hash) + r = strings.ReplaceAll(r, ":", "") + r = strings.ReplaceAll(r, "-", "") + return r +} + +type adaptiveConnFactory struct { + NewFunc func(addr net.Addr) (net.PacketConn, error) + Obfuscator obfs.Obfuscator // nil if no obfuscation +} + +func (f *adaptiveConnFactory) New(addr net.Addr) (net.PacketConn, error) { + if f.Obfuscator == nil { + return f.NewFunc(addr) + } else { + conn, err := f.NewFunc(addr) + if err != nil { + return nil, err + } + return obfs.WrapPacketConn(conn, f.Obfuscator), nil + } +} diff --git a/transport/hysteria2/app/cmd/errors.go b/transport/hysteria2/app/cmd/errors.go new file mode 100644 index 00000000..3d0234aa --- /dev/null +++ b/transport/hysteria2/app/cmd/errors.go @@ -0,0 +1,18 @@ +package cmd + +import ( + "fmt" +) + +type configError struct { + Field string + Err error +} + +func (e configError) Error() string { + return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err) +} + +func (e configError) Unwrap() error { + return e.Err +} diff --git a/transport/hysteria2/app/utils/bpsconv.go b/transport/hysteria2/app/utils/bpsconv.go new file mode 100644 index 00000000..7cad5580 --- /dev/null +++ b/transport/hysteria2/app/utils/bpsconv.go @@ -0,0 +1,68 @@ +package utils + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +const ( + Byte = 1 + Kilobyte = Byte * 1000 + Megabyte = Kilobyte * 1000 + Gigabyte = Megabyte * 1000 + Terabyte = Gigabyte * 1000 +) + +// StringToBps converts a string to a bandwidth value in bytes per second. +// E.g. "100 Mbps", "512 kbps", "1g" are all valid. +func StringToBps(s string) (uint64, error) { + s = strings.ToLower(strings.TrimSpace(s)) + spl := 0 + for i, c := range s { + if c < '0' || c > '9' { + spl = i + break + } + } + if spl == 0 { + // No unit or no value + return 0, errors.New("invalid format") + } + v, err := strconv.ParseUint(s[:spl], 10, 64) + if err != nil { + return 0, err + } + unit := strings.TrimSpace(s[spl:]) + + switch strings.ToLower(unit) { + case "b", "bps": + return v * Byte / 8, nil + case "k", "kb", "kbps": + return v * Kilobyte / 8, nil + case "m", "mb", "mbps": + return v * Megabyte / 8, nil + case "g", "gb", "gbps": + return v * Gigabyte / 8, nil + case "t", "tb", "tbps": + return v * Terabyte / 8, nil + default: + return 0, errors.New("unsupported unit") + } +} + +// ConvBandwidth handles both string and int types for bandwidth. +// When using string, it will be parsed as a bandwidth string with units. +// When using int, it will be parsed as a raw bandwidth in bytes per second. +// It does NOT support float types. 
+func ConvBandwidth(bw interface{}) (uint64, error) { + switch bwT := bw.(type) { + case string: + return StringToBps(bwT) + case int: + return uint64(bwT), nil + default: + return 0, fmt.Errorf("invalid type %T for bandwidth", bwT) + } +} diff --git a/transport/hysteria2/core/client/client.go b/transport/hysteria2/core/client/client.go new file mode 100644 index 00000000..aab5a04f --- /dev/null +++ b/transport/hysteria2/core/client/client.go @@ -0,0 +1,316 @@ +package client + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/utils" + + "github.com/metacubex/quic-go" + "github.com/metacubex/quic-go/http3" +) + +const ( + closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError + closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError +) + +type Client interface { + TCP(addr string) (net.Conn, error) + UDP() (HyUDPConn, error) + Close() error +} + +type HyUDPConn interface { + Receive() ([]byte, string, error) + Send([]byte, string) error + net.PacketConn +} + +type HandshakeInfo struct { + UDPEnabled bool + Tx uint64 // 0 if using BBR +} + +func NewClient(config *Config) (Client, *HandshakeInfo, error) { + if err := config.verifyAndFill(); err != nil { + return nil, nil, err + } + c := &clientImpl{ + config: config, + } + info, err := c.connect() + if err != nil { + return nil, nil, err + } + return c, info, nil +} + +type clientImpl struct { + config *Config + + pktConn net.PacketConn + conn quic.Connection + + udpSM *udpSessionManager +} + +func (c *clientImpl) connect() (*HandshakeInfo, error) { + pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr) + if err != nil { + return nil, err + } + // Convert config to TLS config & QUIC config + tlsConfig := &tls.Config{ + ServerName: c.config.TLSConfig.ServerName, + InsecureSkipVerify: c.config.TLSConfig.InsecureSkipVerify, + VerifyPeerCertificate: c.config.TLSConfig.VerifyPeerCertificate, + RootCAs: c.config.TLSConfig.RootCAs, + } + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: c.config.QUICConfig.InitialStreamReceiveWindow, + MaxStreamReceiveWindow: c.config.QUICConfig.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: c.config.QUICConfig.InitialConnectionReceiveWindow, + MaxConnectionReceiveWindow: c.config.QUICConfig.MaxConnectionReceiveWindow, + MaxIdleTimeout: c.config.QUICConfig.MaxIdleTimeout, + KeepAlivePeriod: c.config.QUICConfig.KeepAlivePeriod, + DisablePathMTUDiscovery: c.config.QUICConfig.DisablePathMTUDiscovery, + EnableDatagrams: true, + } + // Prepare RoundTripper + var conn quic.EarlyConnection + rt := &http3.RoundTripper{ + EnableDatagrams: true, + TLSClientConfig: tlsConfig, + QuicConfig: quicConfig, + Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + qc, err := quic.DialEarly(ctx, pktConn, c.config.ServerAddr, tlsCfg, cfg) + if err != nil { + return nil, err + } + conn = qc + return qc, nil + }, + } + // Send auth HTTP request + req := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Scheme: "https", + Host: protocol.URLHost, + Path: protocol.URLPath, + }, + Header: make(http.Header), + } + protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{ + Auth: c.config.Auth, + Rx: 
c.config.BandwidthConfig.MaxRx, + }) + resp, err := rt.RoundTrip(req) + if err != nil { + if conn != nil { + _ = conn.CloseWithError(closeErrCodeProtocolError, "") + } + _ = pktConn.Close() + return nil, coreErrs.ConnectError{Err: err} + } + if resp.StatusCode != protocol.StatusAuthOK { + _ = conn.CloseWithError(closeErrCodeProtocolError, "") + _ = pktConn.Close() + return nil, coreErrs.AuthError{StatusCode: resp.StatusCode} + } + // Auth OK + authResp := protocol.AuthResponseFromHeader(resp.Header) + var actualTx uint64 + if authResp.RxAuto { + // Server asks client to use bandwidth detection, + // ignore local bandwidth config and use BBR + congestion.UseBBR(conn) + } else { + // actualTx = min(serverRx, clientTx) + actualTx = authResp.Rx + if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { + // Server doesn't have a limit, or our clientTx is smaller than serverRx + actualTx = c.config.BandwidthConfig.MaxTx + } + if actualTx > 0 { + congestion.UseBrutal(conn, actualTx) + } else { + // We don't know our own bandwidth either, use BBR + congestion.UseBBR(conn) + } + } + _ = resp.Body.Close() + + c.pktConn = pktConn + c.conn = conn + if authResp.UDPEnabled { + c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) + } + return &HandshakeInfo{ + UDPEnabled: authResp.UDPEnabled, + Tx: actualTx, + }, nil +} + +// openStream wraps the stream with QStream, which handles Close() properly +func (c *clientImpl) openStream() (quic.Stream, error) { + stream, err := c.conn.OpenStream() + if err != nil { + return nil, err + } + return &utils.QStream{Stream: stream}, nil +} + +func (c *clientImpl) TCP(addr string) (net.Conn, error) { + stream, err := c.openStream() + if err != nil { + return nil, wrapIfConnectionClosed(err) + } + // Send request + err = protocol.WriteTCPRequest(stream, addr) + if err != nil { + _ = stream.Close() + return nil, wrapIfConnectionClosed(err) + } + if c.config.FastOpen { + // Don't wait for the response when fast open is enabled. + // Return the connection immediately, defer the response handling + // to the first Read() call. + return &tcpConn{ + Orig: stream, + PseudoLocalAddr: c.conn.LocalAddr(), + PseudoRemoteAddr: c.conn.RemoteAddr(), + Established: false, + }, nil + } + // Read response + ok, msg, err := protocol.ReadTCPResponse(stream) + if err != nil { + _ = stream.Close() + return nil, wrapIfConnectionClosed(err) + } + if !ok { + _ = stream.Close() + return nil, coreErrs.DialError{Message: msg} + } + return &tcpConn{ + Orig: stream, + PseudoLocalAddr: c.conn.LocalAddr(), + PseudoRemoteAddr: c.conn.RemoteAddr(), + Established: true, + }, nil +} + +func (c *clientImpl) UDP() (HyUDPConn, error) { + if c.udpSM == nil { + return nil, coreErrs.DialError{Message: "UDP not enabled"} + } + return c.udpSM.NewUDP() +} + +func (c *clientImpl) Close() error { + _ = c.conn.CloseWithError(closeErrCodeOK, "") + _ = c.pktConn.Close() + return nil +} + +// wrapIfConnectionClosed checks if the error returned by quic-go +// indicates that the QUIC connection has been permanently closed, +// and if so, wraps the error with coreErrs.ClosedError. +// PITFALL: sometimes quic-go has "internal errors" that are not net.Error, +// but we still need to treat them as ClosedError. 
+func wrapIfConnectionClosed(err error) error { + netErr, ok := err.(net.Error) + if !ok || !netErr.Temporary() { + return coreErrs.ClosedError{Err: err} + } else { + return err + } +} + +type tcpConn struct { + Orig quic.Stream + PseudoLocalAddr net.Addr + PseudoRemoteAddr net.Addr + Established bool +} + +func (c *tcpConn) Read(b []byte) (n int, err error) { + if !c.Established { + // Read response + ok, msg, err := protocol.ReadTCPResponse(c.Orig) + if err != nil { + return 0, err + } + if !ok { + return 0, coreErrs.DialError{Message: msg} + } + c.Established = true + } + return c.Orig.Read(b) +} + +func (c *tcpConn) Write(b []byte) (n int, err error) { + return c.Orig.Write(b) +} + +func (c *tcpConn) Close() error { + return c.Orig.Close() +} + +func (c *tcpConn) LocalAddr() net.Addr { + return c.PseudoLocalAddr +} + +func (c *tcpConn) RemoteAddr() net.Addr { + return c.PseudoRemoteAddr +} + +func (c *tcpConn) SetDeadline(t time.Time) error { + return c.Orig.SetDeadline(t) +} + +func (c *tcpConn) SetReadDeadline(t time.Time) error { + return c.Orig.SetReadDeadline(t) +} + +func (c *tcpConn) SetWriteDeadline(t time.Time) error { + return c.Orig.SetWriteDeadline(t) +} + +type udpIOImpl struct { + Conn quic.Connection +} + +func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) { + for { + msg, err := io.Conn.ReceiveDatagram(context.Background()) + if err != nil { + // Connection error, this will stop the session manager + return nil, err + } + udpMsg, err := protocol.ParseUDPMessage(msg) + if err != nil { + // Invalid message, this is fine - just wait for the next + continue + } + return udpMsg, nil + } +} + +func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error { + msgN := msg.Serialize(buf) + if msgN < 0 { + // Message larger than buffer, silent drop + return nil + } + return io.Conn.SendDatagram(buf[:msgN]) +} diff --git a/transport/hysteria2/core/client/config.go b/transport/hysteria2/core/client/config.go new file mode 100644 index 00000000..0ec7f4f3 --- /dev/null +++ b/transport/hysteria2/core/client/config.go @@ -0,0 +1,112 @@ +package client + +import ( + "crypto/x509" + "net" + "time" + + "github.com/metacubex/mihomo/transport/hysteria2/core/errors" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/pmtud" +) + +const ( + defaultStreamReceiveWindow = 8388608 // 8MB + defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB + defaultMaxIdleTimeout = 30 * time.Second + defaultKeepAlivePeriod = 10 * time.Second +) + +type Config struct { + ConnFactory ConnFactory + ServerAddr net.Addr + Auth string + TLSConfig TLSConfig + QUICConfig QUICConfig + BandwidthConfig BandwidthConfig + FastOpen bool + + filled bool // whether the fields have been verified and filled +} + +// verifyAndFill fills the fields that are not set by the user with default values when possible, +// and returns an error if the user has not set a required field or has set an invalid value. 
+func (c *Config) verifyAndFill() error { + if c.filled { + return nil + } + if c.ConnFactory == nil { + c.ConnFactory = &udpConnFactory{} + } + if c.ServerAddr == nil { + return errors.ConfigError{Field: "ServerAddr", Reason: "must be set"} + } + if c.QUICConfig.InitialStreamReceiveWindow == 0 { + c.QUICConfig.InitialStreamReceiveWindow = defaultStreamReceiveWindow + } else if c.QUICConfig.InitialStreamReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.InitialStreamReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.MaxStreamReceiveWindow == 0 { + c.QUICConfig.MaxStreamReceiveWindow = defaultStreamReceiveWindow + } else if c.QUICConfig.MaxStreamReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.MaxStreamReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.InitialConnectionReceiveWindow == 0 { + c.QUICConfig.InitialConnectionReceiveWindow = defaultConnReceiveWindow + } else if c.QUICConfig.InitialConnectionReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.InitialConnectionReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.MaxConnectionReceiveWindow == 0 { + c.QUICConfig.MaxConnectionReceiveWindow = defaultConnReceiveWindow + } else if c.QUICConfig.MaxConnectionReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.MaxConnectionReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.MaxIdleTimeout == 0 { + c.QUICConfig.MaxIdleTimeout = defaultMaxIdleTimeout + } else if c.QUICConfig.MaxIdleTimeout < 4*time.Second || c.QUICConfig.MaxIdleTimeout > 120*time.Second { + return errors.ConfigError{Field: "QUICConfig.MaxIdleTimeout", Reason: "must be between 4s and 120s"} + } + if c.QUICConfig.KeepAlivePeriod == 0 { + c.QUICConfig.KeepAlivePeriod = defaultKeepAlivePeriod + } else if c.QUICConfig.KeepAlivePeriod < 2*time.Second || c.QUICConfig.KeepAlivePeriod > 60*time.Second { + return errors.ConfigError{Field: "QUICConfig.KeepAlivePeriod", Reason: "must be between 2s and 60s"} + } + c.QUICConfig.DisablePathMTUDiscovery = c.QUICConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery + + c.filled = true + return nil +} + +type ConnFactory interface { + New(net.Addr) (net.PacketConn, error) +} + +type udpConnFactory struct{} + +func (f *udpConnFactory) New(addr net.Addr) (net.PacketConn, error) { + return net.ListenUDP("udp", nil) +} + +// TLSConfig contains the TLS configuration fields that we want to expose to the user. +type TLSConfig struct { + ServerName string + InsecureSkipVerify bool + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + RootCAs *x509.CertPool +} + +// QUICConfig contains the QUIC configuration fields that we want to expose to the user. +type QUICConfig struct { + InitialStreamReceiveWindow uint64 + MaxStreamReceiveWindow uint64 + InitialConnectionReceiveWindow uint64 + MaxConnectionReceiveWindow uint64 + MaxIdleTimeout time.Duration + KeepAlivePeriod time.Duration + DisablePathMTUDiscovery bool // The server may still override this to true on unsupported platforms. +} + +// BandwidthConfig describes the maximum bandwidth that the server can use, in bytes per second. 
+type BandwidthConfig struct { + MaxTx uint64 + MaxRx uint64 +} diff --git a/transport/hysteria2/core/client/reconnect.go b/transport/hysteria2/core/client/reconnect.go new file mode 100644 index 00000000..10f0c7d3 --- /dev/null +++ b/transport/hysteria2/core/client/reconnect.go @@ -0,0 +1,117 @@ +package client + +import ( + "net" + "sync" + + coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors" +) + +// reconnectableClientImpl is a wrapper of Client, which can reconnect when the connection is closed, +// except when the caller explicitly calls Close() to permanently close this client. +type reconnectableClientImpl struct { + configFunc func() (*Config, error) // called before connecting + connectedFunc func(Client, *HandshakeInfo, int) // called when successfully connected + client Client + count int + m sync.Mutex + closed bool // permanent close +} + +// NewReconnectableClient creates a reconnectable client. +// If lazy is true, the client will not connect until the first call to TCP() or UDP(). +// We use a function for config mainly to delay config evaluation +// (which involves DNS resolution) until the actual connection attempt. +func NewReconnectableClient(configFunc func() (*Config, error), connectedFunc func(Client, *HandshakeInfo, int), lazy bool) (Client, error) { + rc := &reconnectableClientImpl{ + configFunc: configFunc, + connectedFunc: connectedFunc, + } + if !lazy { + if err := rc.reconnect(); err != nil { + return nil, err + } + } + return rc, nil +} + +func (rc *reconnectableClientImpl) reconnect() error { + if rc.client != nil { + _ = rc.client.Close() + } + var info *HandshakeInfo + config, err := rc.configFunc() + if err != nil { + return err + } + rc.client, info, err = NewClient(config) + if err != nil { + return err + } else { + rc.count++ + if rc.connectedFunc != nil { + rc.connectedFunc(rc, info, rc.count) + } + return nil + } +} + +func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { + rc.m.Lock() + defer rc.m.Unlock() + if rc.closed { + return nil, coreErrs.ClosedError{} + } + if rc.client == nil { + // No active connection, connect first + if err := rc.reconnect(); err != nil { + return nil, err + } + } + conn, err := rc.client.TCP(addr) + if _, ok := err.(coreErrs.ClosedError); ok { + // Connection closed, reconnect + if err := rc.reconnect(); err != nil { + return nil, err + } + return rc.client.TCP(addr) + } else { + // OK or some other temporary error + return conn, err + } +} + +func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { + rc.m.Lock() + defer rc.m.Unlock() + if rc.closed { + return nil, coreErrs.ClosedError{} + } + if rc.client == nil { + // No active connection, connect first + if err := rc.reconnect(); err != nil { + return nil, err + } + } + conn, err := rc.client.UDP() + if _, ok := err.(coreErrs.ClosedError); ok { + // Connection closed, reconnect + if err := rc.reconnect(); err != nil { + return nil, err + } + return rc.client.UDP() + } else { + // OK or some other temporary error + return conn, err + } +} + +func (rc *reconnectableClientImpl) Close() error { + rc.m.Lock() + defer rc.m.Unlock() + rc.closed = true + if rc.client != nil { + return rc.client.Close() + } + return nil +} diff --git a/transport/hysteria2/core/client/udp.go b/transport/hysteria2/core/client/udp.go new file mode 100644 index 00000000..b53496db --- /dev/null +++ b/transport/hysteria2/core/client/udp.go @@ -0,0 +1,223 @@ +package client + +import ( + "errors" + "io" + "math/rand" + "net" + "sync" + "time" + + coreErrs 
"github.com/metacubex/mihomo/transport/hysteria2/core/errors" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/frag" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol" + "github.com/metacubex/quic-go" + M "github.com/sagernet/sing/common/metadata" +) + +const ( + udpMessageChanSize = 1024 +) + +type udpIO interface { + ReceiveMessage() (*protocol.UDPMessage, error) + SendMessage([]byte, *protocol.UDPMessage) error +} + +type udpConn struct { + ID uint32 + D *frag.Defragger + ReceiveCh chan *protocol.UDPMessage + SendBuf []byte + SendFunc func([]byte, *protocol.UDPMessage) error + CloseFunc func() + Closed bool +} + +func (u *udpConn) Receive() ([]byte, string, error) { + for { + msg := <-u.ReceiveCh + if msg == nil { + // Closed + return nil, "", io.EOF + } + dfMsg := u.D.Feed(msg) + if dfMsg == nil { + // Incomplete message, wait for more + continue + } + return dfMsg.Data, dfMsg.Addr, nil + } +} + +func (u *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + bs, addrStr, err := u.Receive() + n = copy(p, bs) + addr = M.ParseSocksaddr(addrStr).UDPAddr() + return +} + +// Send is not thread-safe, as it uses a shared SendBuf. +func (u *udpConn) Send(data []byte, addr string) error { + // Try no frag first + msg := &protocol.UDPMessage{ + SessionID: u.ID, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: addr, + Data: data, + } + err := u.SendFunc(u.SendBuf, msg) + var errTooLarge quic.ErrMessageTooLarge + if errors.As(err, &errTooLarge) { + // Message too large, try fragmentation + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) + for _, fMsg := range fMsgs { + err := u.SendFunc(u.SendBuf, &fMsg) + if err != nil { + return err + } + } + return nil + } else { + return err + } +} + +func (u *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + err = u.Send(p, M.SocksaddrFromNet(addr).String()) + if err != nil { + return + } + n = len(p) + return +} + +func (u *udpConn) Close() error { + u.CloseFunc() + return nil +} + +func (u *udpConn) LocalAddr() net.Addr { + // a fake implementation to satisfy net.PacketConn + return nil +} + +func (u *udpConn) SetDeadline(t time.Time) error { + // a fake implementation to satisfy net.PacketConn + return nil +} + +func (u *udpConn) SetReadDeadline(t time.Time) error { + // a fake implementation to satisfy net.PacketConn + return nil +} + +func (u *udpConn) SetWriteDeadline(t time.Time) error { + // a fake implementation to satisfy net.PacketConn + return nil +} + +type udpSessionManager struct { + io udpIO + + mutex sync.RWMutex + m map[uint32]*udpConn + nextID uint32 + + closed bool +} + +func newUDPSessionManager(io udpIO) *udpSessionManager { + m := &udpSessionManager{ + io: io, + m: make(map[uint32]*udpConn), + nextID: 1, + } + go m.run() + return m +} + +func (m *udpSessionManager) run() error { + defer m.closeCleanup() + for { + msg, err := m.io.ReceiveMessage() + if err != nil { + return err + } + m.feed(msg) + } +} + +func (m *udpSessionManager) closeCleanup() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, conn := range m.m { + m.close(conn) + } + m.closed = true +} + +func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + conn, ok := m.m[msg.SessionID] + if !ok { + // Ignore message from unknown session + return + } + + select { + case conn.ReceiveCh <- msg: + // OK + default: + // Channel full, drop the message + } +} + +// NewUDP creates a new UDP session. 
+func (m *udpSessionManager) NewUDP() (HyUDPConn, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closed { + return nil, coreErrs.ClosedError{} + } + + id := m.nextID + m.nextID++ + + conn := &udpConn{ + ID: id, + D: &frag.Defragger{}, + ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize), + SendBuf: make([]byte, protocol.MaxUDPSize), + SendFunc: m.io.SendMessage, + } + conn.CloseFunc = func() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.close(conn) + } + m.m[id] = conn + + return conn, nil +} + +func (m *udpSessionManager) close(conn *udpConn) { + if !conn.Closed { + conn.Closed = true + close(conn.ReceiveCh) + delete(m.m, conn.ID) + } +} + +func (m *udpSessionManager) Count() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + return len(m.m) +} diff --git a/transport/hysteria2/core/errors/errors.go b/transport/hysteria2/core/errors/errors.go new file mode 100644 index 00000000..cb691184 --- /dev/null +++ b/transport/hysteria2/core/errors/errors.go @@ -0,0 +1,75 @@ +package errors + +import ( + "fmt" + "strconv" +) + +// ConfigError is returned when a configuration field is invalid. +type ConfigError struct { + Field string + Reason string +} + +func (c ConfigError) Error() string { + return fmt.Sprintf("invalid config: %s: %s", c.Field, c.Reason) +} + +// ConnectError is returned when the client fails to connect to the server. +type ConnectError struct { + Err error +} + +func (c ConnectError) Error() string { + return "connect error: " + c.Err.Error() +} + +func (c ConnectError) Unwrap() error { + return c.Err +} + +// AuthError is returned when the client fails to authenticate with the server. +type AuthError struct { + StatusCode int +} + +func (a AuthError) Error() string { + return "authentication error, HTTP status code: " + strconv.Itoa(a.StatusCode) +} + +// DialError is returned when the server rejects the client's dial request. +// This applies to both TCP and UDP. +type DialError struct { + Message string +} + +func (c DialError) Error() string { + return "dial error: " + c.Message +} + +// ClosedError is returned when the client attempts to use a closed connection. +type ClosedError struct { + Err error // Can be nil +} + +func (c ClosedError) Error() string { + if c.Err == nil { + return "connection closed" + } else { + return "connection closed: " + c.Err.Error() + } +} + +func (c ClosedError) Unwrap() error { + return c.Err +} + +// ProtocolError is returned when the server/client runs into an unexpected +// or malformed request/response/message. 
+type ProtocolError struct { + Message string +} + +func (p ProtocolError) Error() string { + return "protocol error: " + p.Message +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/bandwidth.go b/transport/hysteria2/core/internal/congestion/bbr/bandwidth.go new file mode 100644 index 00000000..97dbe07e --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/bandwidth.go @@ -0,0 +1,27 @@ +package bbr + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + infBandwidth = Bandwidth(math.MaxUint64) +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/bandwidth_sampler.go b/transport/hysteria2/core/internal/congestion/bbr/bandwidth_sampler.go new file mode 100644 index 00000000..a9c1b1f7 --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/bandwidth_sampler.go @@ -0,0 +1,874 @@ +package bbr + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + infRTT = time.Duration(math.MaxInt64) + defaultConnectionStateMapQueueSize = 256 + defaultCandidatesBufferSize = 256 +) + +type roundTripCount uint64 + +// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned +// to the caller when the packet is acked or lost. +type sendTimeState struct { + // Whether other states in this object is valid. + isValid bool + // Whether the sender is app limited at the time the packet was sent. + // App limited bandwidth sample might be artificially low because the sender + // did not have enough data to send in order to saturate the link. + isAppLimited bool + // Total number of sent bytes at the time the packet was sent. + // Includes the packet itself. + totalBytesSent congestion.ByteCount + // Total number of acked bytes at the time the packet was sent. + totalBytesAcked congestion.ByteCount + // Total number of lost bytes at the time the packet was sent. + totalBytesLost congestion.ByteCount + // Total number of inflight bytes at the time the packet was sent. + // Includes the packet itself. + // It should be equal to |total_bytes_sent| minus the sum of + // |total_bytes_acked|, |total_bytes_lost| and total neutered bytes. + bytesInFlight congestion.ByteCount +} + +func newSendTimeState( + isAppLimited bool, + totalBytesSent congestion.ByteCount, + totalBytesAcked congestion.ByteCount, + totalBytesLost congestion.ByteCount, + bytesInFlight congestion.ByteCount, +) *sendTimeState { + return &sendTimeState{ + isValid: true, + isAppLimited: isAppLimited, + totalBytesSent: totalBytesSent, + totalBytesAcked: totalBytesAcked, + totalBytesLost: totalBytesLost, + bytesInFlight: bytesInFlight, + } +} + +type extraAckedEvent struct { + // The excess bytes acknowlwedged in the time delta for this event. + extraAcked congestion.ByteCount + + // The bytes acknowledged and time delta from the event. + bytesAcked congestion.ByteCount + timeDelta time.Duration + // The round trip of the event. 
+ round roundTripCount +} + +func maxExtraAckedEventFunc(a, b extraAckedEvent) int { + if a.extraAcked > b.extraAcked { + return 1 + } else if a.extraAcked < b.extraAcked { + return -1 + } + return 0 +} + +// BandwidthSample +type bandwidthSample struct { + // The bandwidth at that particular sample. Zero if no valid bandwidth sample + // is available. + bandwidth Bandwidth + // The RTT measurement at this particular sample. Zero if no RTT sample is + // available. Does not correct for delayed ack time. + rtt time.Duration + // |send_rate| is computed from the current packet being acked('P') and an + // earlier packet that is acked before P was sent. + sendRate Bandwidth + // States captured when the packet was sent. + stateAtSend sendTimeState +} + +func newBandwidthSample() *bandwidthSample { + return &bandwidthSample{ + sendRate: infBandwidth, + } +} + +// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every +// ack event to keep track the degree of ack aggregation(a.k.a "ack height"). +type maxAckHeightTracker struct { + // Tracks the maximum number of bytes acked faster than the estimated + // bandwidth. + maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount] + // The time this aggregation started and the number of bytes acked during it. + aggregationEpochStartTime time.Time + aggregationEpochBytes congestion.ByteCount + // The last sent packet number before the current aggregation epoch started. + lastSentPacketNumberBeforeEpoch congestion.PacketNumber + // The number of ack aggregation epochs ever started, including the ongoing + // one. Stats only. + numAckAggregationEpochs uint64 + ackAggregationBandwidthThreshold float64 + startNewAggregationEpochAfterFullRound bool + reduceExtraAckedOnBandwidthIncrease bool +} + +func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker { + return &maxAckHeightTracker{ + maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc), + lastSentPacketNumberBeforeEpoch: invalidPacketNumber, + ackAggregationBandwidthThreshold: 1.0, + } +} + +func (m *maxAckHeightTracker) Get() congestion.ByteCount { + return m.maxAckHeightFilter.GetBest().extraAcked +} + +func (m *maxAckHeightTracker) Update( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, + lastSentPacketNumber congestion.PacketNumber, + lastAckedPacketNumber congestion.PacketNumber, + ackTime time.Time, + bytesAcked congestion.ByteCount, +) congestion.ByteCount { + forceNewEpoch := false + + if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth { + // Save and clear existing entries. + best := m.maxAckHeightFilter.GetBest() + secondBest := m.maxAckHeightFilter.GetSecondBest() + thirdBest := m.maxAckHeightFilter.GetThirdBest() + m.maxAckHeightFilter.Clear() + + // Reinsert the heights into the filter after recalculating. 
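+ // Each saved event is re-evaluated against the new, higher bandwidth
+ // estimate: its extraAcked becomes the bytes acked beyond what that estimate
+ // would have delivered over the same time delta, and events fully explained
+ // by the new estimate are dropped rather than reinserted.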
+ expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta) + if expectedBytesAcked < best.bytesAcked { + best.extraAcked = best.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(best, best.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta) + if expectedBytesAcked < secondBest.bytesAcked { + secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(secondBest, secondBest.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta) + if expectedBytesAcked < thirdBest.bytesAcked { + thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(thirdBest, thirdBest.round) + } + } + + // If any packet sent after the start of the epoch has been acked, start a new + // epoch. + if m.startNewAggregationEpochAfterFullRound && + m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber && + lastAckedPacketNumber != invalidPacketNumber && + lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch { + forceNewEpoch = true + } + if m.aggregationEpochStartTime.IsZero() || forceNewEpoch { + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + // Compute how many bytes are expected to be delivered, assuming max bandwidth + // is correct. + aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime) + expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta) + // Reset the current aggregation epoch as soon as the ack arrival rate is less + // than or equal to the max bandwidth. + if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) { + // Reset to start measuring a new aggregation epoch. + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + m.aggregationEpochBytes += bytesAcked + + // Compute how many extra bytes were delivered vs max bandwidth. 
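+ // For example, with a 100 Mbps estimate (12,500,000 bytes/s) and a 40 ms
+ // aggregation epoch, expectedBytesAcked is 500,000 bytes; if 600,000 bytes
+ // were acked in that epoch, 100,000 bytes are attributed to ack aggregation.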
+ extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked + newEvent := extraAckedEvent{ + extraAcked: extraBytesAcked, + bytesAcked: m.aggregationEpochBytes, + timeDelta: aggregationDelta, + } + m.maxAckHeightFilter.Update(newEvent, roundTripCount) + return extraBytesAcked +} + +func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) { + m.maxAckHeightFilter.SetWindowLength(length) +} + +func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) { + newEvent := extraAckedEvent{ + extraAcked: newHeight, + round: newTime, + } + m.maxAckHeightFilter.Reset(newEvent, newTime) +} + +func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) { + m.ackAggregationBandwidthThreshold = threshold +} + +func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) { + m.startNewAggregationEpochAfterFullRound = value +} + +func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + m.reduceExtraAckedOnBandwidthIncrease = value +} + +func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 { + return m.ackAggregationBandwidthThreshold +} + +func (m *maxAckHeightTracker) NumAckAggregationEpochs() uint64 { + return m.numAckAggregationEpochs +} + +// AckPoint represents a point on the ack line. +type ackPoint struct { + ackTime time.Time + totalBytesAcked congestion.ByteCount +} + +// RecentAckPoints maintains the most recent 2 ack points at distinct times. +type recentAckPoints struct { + ackPoints [2]ackPoint +} + +func (r *recentAckPoints) Update(ackTime time.Time, totalBytesAcked congestion.ByteCount) { + if ackTime.Before(r.ackPoints[1].ackTime) { + r.ackPoints[1].ackTime = ackTime + } else if ackTime.After(r.ackPoints[1].ackTime) { + r.ackPoints[0] = r.ackPoints[1] + r.ackPoints[1].ackTime = ackTime + } + + r.ackPoints[1].totalBytesAcked = totalBytesAcked +} + +func (r *recentAckPoints) Clear() { + r.ackPoints[0] = ackPoint{} + r.ackPoints[1] = ackPoint{} +} + +func (r *recentAckPoints) MostRecentPoint() *ackPoint { + return &r.ackPoints[1] +} + +func (r *recentAckPoints) LessRecentPoint() *ackPoint { + if r.ackPoints[0].totalBytesAcked != 0 { + return &r.ackPoints[0] + } + + return &r.ackPoints[1] +} + +// ConnectionStateOnSentPacket represents the information about a sent packet +// and the state of the connection at the moment the packet was sent, +// specifically the information about the most recently acknowledged packet at +// that moment. +type connectionStateOnSentPacket struct { + // Time at which the packet is sent. + sentTime time.Time + // Size of the packet. + size congestion.ByteCount + // The value of |totalBytesSentAtLastAckedPacket| at the time the + // packet was sent. + totalBytesSentAtLastAckedPacket congestion.ByteCount + // The value of |lastAckedPacketSentTime| at the time the packet was + // sent. + lastAckedPacketSentTime time.Time + // The value of |lastAckedPacketAckTime| at the time the packet was + // sent. + lastAckedPacketAckTime time.Time + // Send time states that are returned to the congestion controller when the + // packet is acked or lost. + sendTimeState sendTimeState +} + +// Snapshot constructor. Records the current state of the bandwidth +// sampler. +// |bytes_in_flight| is the bytes in flight right after the packet is sent.
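+// The snapshot is what onPacketAcknowledged later reads to compute the send
+// rate and ack rate slopes for this packet once it is acked.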
+func newConnectionStateOnSentPacket( + sentTime time.Time, + size congestion.ByteCount, + bytesInFlight congestion.ByteCount, + sampler *bandwidthSampler, +) *connectionStateOnSentPacket { + return &connectionStateOnSentPacket{ + sentTime: sentTime, + size: size, + totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket, + lastAckedPacketSentTime: sampler.lastAckedPacketSentTime, + lastAckedPacketAckTime: sampler.lastAckedPacketAckTime, + sendTimeState: *newSendTimeState( + sampler.isAppLimited, + sampler.totalBytesSent, + sampler.totalBytesAcked, + sampler.totalBytesLost, + bytesInFlight, + ), + } +} + +// BandwidthSampler keeps track of sent and acknowledged packets and outputs a +// bandwidth sample for every packet acknowledged. The samples are taken for +// individual packets, and are not filtered; the consumer has to filter the +// bandwidth samples itself. In certain cases, the sampler will locally severely +// underestimate the bandwidth, hence a maximum filter with a size of at least +// one RTT is recommended. +// +// This class bases its samples on the slope of two curves: the number of bytes +// sent over time, and the number of bytes acknowledged as received over time. +// It produces a sample of both slopes for every packet that gets acknowledged, +// based on a slope between two points on each of the corresponding curves. Note +// that due to the packet loss, the number of bytes on each curve might get +// further and further away from each other, meaning that it is not feasible to +// compare byte values coming from different curves with each other. +// +// The obvious points for measuring slope sample are the ones corresponding to +// the packet that was just acknowledged. Let us denote them as S_1 (point at +// which the current packet was sent) and A_1 (point at which the current packet +// was acknowledged). However, taking a slope requires two points on each line, +// so estimating bandwidth requires picking a packet in the past with respect to +// which the slope is measured. +// +// For that purpose, BandwidthSampler always keeps track of the most recently +// acknowledged packet, and records it together with every outgoing packet. +// When a packet gets acknowledged (A_1), it has not only information about when +// it itself was sent (S_1), but also the information about the latest +// acknowledged packet right before it was sent (S_0 and A_0). +// +// Based on that data, send and ack rate are estimated as: +// +// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) +// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) +// +// Here, the ack rate is intuitively the rate we want to treat as bandwidth. +// However, in certain cases (e.g. ack compression) the ack rate at a point may +// end up higher than the rate at which the data was originally sent, which is +// not indicative of the real bandwidth. Hence, we use the send rate as an upper +// bound, and the sample value is +// +// rate_sample = min(send_rate, ack_rate) +// +// An important edge case handled by the sampler is tracking the app-limited +// samples. There are multiple meaning of "app-limited" used interchangeably, +// hence it is important to understand and to be able to distinguish between +// them. +// +// Meaning 1: connection state. The connection is said to be app-limited when +// there is no outstanding data to send. 
This means that certain bandwidth +// samples in the future would not be an accurate indication of the link +// capacity, and it is important to inform consumer about that. Whenever +// connection becomes app-limited, the sampler is notified via OnAppLimited() +// method. +// +// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth +// sampler becomes notified about the connection being app-limited, it enters +// app-limited phase. In that phase, all *sent* packets are marked as +// app-limited. Note that the connection itself does not have to be +// app-limited during the app-limited phase, and in fact it will not be +// (otherwise how would it send packets?). The boolean flag below indicates +// whether the sampler is in that phase. +// +// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is +// sent during the app-limited phase, the resulting sample related to the +// packet will be marked as app-limited. +// +// With the terminology issue out of the way, let us consider the question of +// what kind of situation it addresses. +// +// Consider a scenario where we first send packets 1 to 20 at a regular +// bandwidth, and then immediately run out of data. After a few seconds, we send +// packets 21 to 60, and only receive ack for 21 between sending packets 40 and +// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 +// we use to compute the slope is going to be packet 20, a few seconds apart +// from the current packet, hence the resulting estimate would be extremely low +// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, +// meaning that the bandwidth sample would exclude the quiescence. +// +// Based on the analysis of that scenario, we implement the following rule: once +// OnAppLimited() is called, all sent packets will produce app-limited samples +// up until an ack for a packet that was sent after OnAppLimited() was called. +// Note that while the scenario above is not the only scenario when the +// connection is app-limited, the approach works in other cases too. + +type congestionEventSample struct { + // The maximum bandwidth sample from all acked packets. + // QuicBandwidth::Zero() if no samples are available. + sampleMaxBandwidth Bandwidth + // Whether |sample_max_bandwidth| is from a app-limited sample. + sampleIsAppLimited bool + // The minimum rtt sample from all acked packets. + // QuicTime::Delta::Infinite() if no samples are available. + sampleRtt time.Duration + // For each packet p in acked packets, this is the max value of INFLIGHT(p), + // where INFLIGHT(p) is the number of bytes acked while p is inflight. + sampleMaxInflight congestion.ByteCount + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + lastPacketSendState sendTimeState + // The number of extra bytes acked from this ack event, compared to what is + // expected from the flow's bandwidth. Larger value means more ack + // aggregation. + extraAcked congestion.ByteCount +} + +func newCongestionEventSample() *congestionEventSample { + return &congestionEventSample{ + sampleRtt: infRTT, + } +} + +type bandwidthSampler struct { + // The total number of congestion controlled bytes sent during the connection. + totalBytesSent congestion.ByteCount + + // The total number of congestion controlled bytes which were acknowledged. 
+ totalBytesAcked congestion.ByteCount + + // The total number of congestion controlled bytes which were lost. + totalBytesLost congestion.ByteCount + + // The total number of congestion controlled bytes which have been neutered. + totalBytesNeutered congestion.ByteCount + + // The value of |total_bytes_sent_| at the time the last acknowledged packet + // was sent. Valid only when |last_acked_packet_sent_time_| is valid. + totalBytesSentAtLastAckedPacket congestion.ByteCount + + // The time at which the last acknowledged packet was sent. Set to + // QuicTime::Zero() if no valid timestamp is available. + lastAckedPacketSentTime time.Time + + // The time at which the most recent packet was acknowledged. + lastAckedPacketAckTime time.Time + + // The most recently sent packet. + lastSentPacket congestion.PacketNumber + + // The most recently acked packet. + lastAckedPacket congestion.PacketNumber + + // Indicates whether the bandwidth sampler is currently in an app-limited + // phase. + isAppLimited bool + + // The packet that will be acknowledged after this one will cause the sampler + // to exit the app-limited phase. + endOfAppLimitedPhase congestion.PacketNumber + + // Record of the connection state at the point where each packet in flight was + // sent, indexed by the packet number. + connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket] + + recentAckPoints recentAckPoints + a0Candidates RingBuffer[ackPoint] + + // Maximum number of tracked packets. + maxTrackedPackets congestion.ByteCount + + maxAckHeightTracker *maxAckHeightTracker + totalBytesAckedAfterLastAckEvent congestion.ByteCount + + // True if connection option 'BSAO' is set. + overestimateAvoidance bool + + // True if connection option 'BBRB' is set. + limitMaxAckHeightTrackerBySendRate bool +} + +func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler { + b := &bandwidthSampler{ + maxAckHeightTracker: newMaxAckHeightTracker(maxAckHeightTrackerWindowLength), + connectionStateMap: newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize), + lastSentPacket: invalidPacketNumber, + lastAckedPacket: invalidPacketNumber, + endOfAppLimitedPhase: invalidPacketNumber, + } + + b.a0Candidates.Init(defaultCandidatesBufferSize) + + return b +} + +func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 { + return b.maxAckHeightTracker.NumAckAggregationEpochs() +} + +func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) { + b.maxAckHeightTracker.SetFilterWindowLength(length) +} + +func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) { + b.maxAckHeightTracker.Reset(newHeight, newTime) +} + +func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) { + b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value) +} + +func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) { + b.limitMaxAckHeightTrackerBySendRate = value +} + +func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value) +} + +func (b *bandwidthSampler) EnableOverestimateAvoidance() { + if b.overestimateAvoidance { + return + } + + b.overestimateAvoidance = true + b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0) +} + +func (b 
*bandwidthSampler) IsOverestimateAvoidanceEnabled() bool { + return b.overestimateAvoidance +} + +func (b *bandwidthSampler) OnPacketSent( + sentTime time.Time, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + bytesInFlight congestion.ByteCount, + isRetransmittable bool, +) { + b.lastSentPacket = packetNumber + + if !isRetransmittable { + return + } + + b.totalBytesSent += bytes + + // If there are no packets in flight, the time at which the new transmission + // opens can be treated as the A_0 point for the purpose of bandwidth + // sampling. This underestimates bandwidth to some extent, and produces some + // artificially low samples for most packets in flight, but it provides with + // samples at important points where we would not have them otherwise, most + // importantly at the beginning of the connection. + if bytesInFlight == 0 { + b.lastAckedPacketAckTime = sentTime + if b.overestimateAvoidance { + b.recentAckPoints.Clear() + b.recentAckPoints.Update(sentTime, b.totalBytesAcked) + b.a0Candidates.Clear() + b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint()) + } + b.totalBytesSentAtLastAckedPacket = b.totalBytesSent + + // In this situation ack compression is not a concern, set send rate to + // effectively infinite. + b.lastAckedPacketSentTime = sentTime + } + + b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket( + sentTime, + bytes, + bytesInFlight+bytes, + b, + )) +} + +func (b *bandwidthSampler) OnCongestionEvent( + ackTime time.Time, + ackedPackets []congestion.AckedPacketInfo, + lostPackets []congestion.LostPacketInfo, + maxBandwidth Bandwidth, + estBandwidthUpperBound Bandwidth, + roundTripCount roundTripCount, +) congestionEventSample { + eventSample := newCongestionEventSample() + + var lastLostPacketSendState sendTimeState + + for _, p := range lostPackets { + sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost) + if sendState.isValid { + lastLostPacketSendState = sendState + } + } + + if len(ackedPackets) == 0 { + // Only populate send state for a loss-only event. + eventSample.lastPacketSendState = lastLostPacketSendState + return *eventSample + } + + var lastAckedPacketSendState sendTimeState + var maxSendRate Bandwidth + + for _, p := range ackedPackets { + sample := b.onPacketAcknowledged(ackTime, p.PacketNumber) + if !sample.stateAtSend.isValid { + continue + } + + lastAckedPacketSendState = sample.stateAtSend + + if sample.rtt != 0 { + eventSample.sampleRtt = min(eventSample.sampleRtt, sample.rtt) + } + if sample.bandwidth > eventSample.sampleMaxBandwidth { + eventSample.sampleMaxBandwidth = sample.bandwidth + eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited + } + if sample.sendRate != infBandwidth { + maxSendRate = max(maxSendRate, sample.sendRate) + } + inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked + if inflightSample > eventSample.sampleMaxInflight { + eventSample.sampleMaxInflight = inflightSample + } + } + + if !lastLostPacketSendState.isValid { + eventSample.lastPacketSendState = lastAckedPacketSendState + } else if !lastAckedPacketSendState.isValid { + eventSample.lastPacketSendState = lastLostPacketSendState + } else { + // If two packets are inflight and an alarm is armed to lose a packet and it + // wakes up late, then the first of two in flight packets could have been + // acknowledged before the wakeup, which re-evaluates loss detection, and + // could declare the later of the two lost. 
+ if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber { + eventSample.lastPacketSendState = lastLostPacketSendState + } else { + eventSample.lastPacketSendState = lastAckedPacketSendState + } + } + + isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth + maxBandwidth = max(maxBandwidth, eventSample.sampleMaxBandwidth) + if b.limitMaxAckHeightTrackerBySendRate { + maxBandwidth = max(maxBandwidth, maxSendRate) + } + + eventSample.extraAcked = b.onAckEventEnd(min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount) + + return *eventSample +} + +func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) { + b.totalBytesLost += bytesLost + if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil { + sentPacketToSendTimeState(sentPacketPointer, &s) + } + return s +} + +func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) { + b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) { + b.totalBytesNeutered += sentPacket.size + }) +} + +func (b *bandwidthSampler) OnAppLimited() { + b.isAppLimited = true + b.endOfAppLimitedPhase = b.lastSentPacket +} + +func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) { + // A packet can become obsolete when it is removed from QuicUnackedPacketMap's + // view of inflight before it is acked or marked as lost. For example, when + // QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet, + // the packet is removed from QuicUnackedPacketMap's inflight, but is not + // marked as acked or lost in the BandwidthSampler. + b.connectionStateMap.RemoveUpTo(leastUnacked) +} + +func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount { + return b.totalBytesSent +} + +func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount { + return b.totalBytesLost +} + +func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount { + return b.totalBytesAcked +} + +func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount { + return b.totalBytesNeutered +} + +func (b *bandwidthSampler) IsAppLimited() bool { + return b.isAppLimited +} + +func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber { + return b.endOfAppLimitedPhase +} + +func (b *bandwidthSampler) max_ack_height() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool { + if b.a0Candidates.Empty() { + return false + } + + if b.a0Candidates.Len() == 1 { + *a0 = *b.a0Candidates.Front() + return true + } + + for i := 1; i < b.a0Candidates.Len(); i++ { + if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked { + *a0 = *b.a0Candidates.Offset(i - 1) + if i > 1 { + for j := 0; j < i-1; j++ { + b.a0Candidates.PopFront() + } + } + return true + } + } + + *a0 = *b.a0Candidates.Back() + for k := 0; k < b.a0Candidates.Len()-1; k++ { + b.a0Candidates.PopFront() + } + return true +} + +func (b *bandwidthSampler) onPacketAcknowledged(ackTime time.Time, packetNumber congestion.PacketNumber) bandwidthSample { + sample := newBandwidthSample() + b.lastAckedPacket = packetNumber + sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber) + if sentPacketPointer == nil { + return *sample + } + + // OnPacketAcknowledgedInner + b.totalBytesAcked += sentPacketPointer.size + 
b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent + b.lastAckedPacketSentTime = sentPacketPointer.sentTime + b.lastAckedPacketAckTime = ackTime + if b.overestimateAvoidance { + b.recentAckPoints.Update(ackTime, b.totalBytesAcked) + } + + if b.isAppLimited { + // Exit app-limited phase in two cases: + // (1) end_of_app_limited_phase_ is not initialized, i.e., so far all + // packets are sent while there are buffered packets or pending data. + // (2) The current acked packet is after the sent packet marked as the end + // of the app limit phase. + if b.endOfAppLimitedPhase == invalidPacketNumber || + packetNumber > b.endOfAppLimitedPhase { + b.isAppLimited = false + } + } + + // There might have been no packets acknowledged at the moment when the + // current packet was sent. In that case, there is no bandwidth sample to + // make. + if sentPacketPointer.lastAckedPacketSentTime.IsZero() { + return *sample + } + + // Infinite rate indicates that the sampler is supposed to discard the + // current send rate sample and use only the ack rate. + sendRate := infBandwidth + if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) { + sendRate = BandwidthFromDelta( + sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket, + sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime)) + } + + var a0 ackPoint + if b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0) { + } else { + a0.ackTime = sentPacketPointer.lastAckedPacketAckTime + a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked + } + + // During the slope calculation, ensure that ack time of the current packet is + // always larger than the time of the previous packet, otherwise division by + // zero or integer underflow can occur. + if ackTime.Sub(a0.ackTime) <= 0 { + return *sample + } + + ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime)) + + sample.bandwidth = min(sendRate, ackRate) + // Note: this sample does not account for delayed acknowledgement time. This + // means that the RTT measurements here can be artificially high, especially + // on low bandwidth connections. + sample.rtt = ackTime.Sub(sentPacketPointer.sentTime) + sample.sendRate = sendRate + sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend) + + return *sample +} + +func (b *bandwidthSampler) onAckEventEnd( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, +) congestion.ByteCount { + newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent + if newlyAckedBytes == 0 { + return 0 + } + b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked + extraAcked := b.maxAckHeightTracker.Update( + bandwidthEstimate, + isNewMaxBandwidth, + roundTripCount, + b.lastSentPacket, + b.lastAckedPacket, + b.lastAckedPacketAckTime, + newlyAckedBytes) + // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack + // aggregation epoch, save LessRecentPoint, which is the last ack point of the + // previous epoch, as a A0 candidate. 
+ if b.overestimateAvoidance && extraAcked == 0 { + b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint()) + } + return extraAcked +} + +func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) { + *sendTimeState = sentPacket.sendTimeState + sendTimeState.isValid = true +} + +// BytesFromBandwidthAndTimeDelta calculates the bytes +// from a bandwidth(bits per second) and a time delta +func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount { + return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) / + (congestion.ByteCount(time.Second) * 8) +} + +func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration { + return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth) +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/bbr_sender.go b/transport/hysteria2/core/internal/congestion/bbr/bbr_sender.go new file mode 100644 index 00000000..4ec11324 --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/bbr_sender.go @@ -0,0 +1,944 @@ +package bbr + +import ( + "fmt" + "math/rand" + "net" + "time" + + "github.com/metacubex/quic-go/congestion" + + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/common" +) + +// BbrSender implements BBR congestion control algorithm. BBR aims to estimate +// the current available Bottleneck Bandwidth and RTT (hence the name), and +// regulates the pacing rate and the size of the congestion window based on +// those signals. +// +// BBR relies on pacing in order to function properly. Do not use BBR when +// pacing is disabled. +// + +const ( + minBps = 65536 // 64 kbps + + invalidPacketNumber = -1 + initialCongestionWindowPackets = 32 + + // Constants based on TCP defaults. + // The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. + // Does not inflate the pacing rate. + defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSizeIPv4) + + // The gain used for the STARTUP, equal to 2/ln(2). + defaultHighGain = 2.885 + // The newly derived gain for STARTUP, equal to 4 * ln(2) + derivedHighGain = 2.773 + // The newly derived CWND gain for STARTUP, 2. + derivedHighCWNDGain = 2.0 +) + +// The cycle of gains used during the PROBE_BW stage. +var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0} + +const ( + // The length of the gain cycle. + gainCycleLength = len(pacingGain) + // The size of the bandwidth filter window, in round-trips. + bandwidthWindowSize = gainCycleLength + 2 + + // The time after which the current min_rtt value expires. + minRttExpiry = 10 * time.Second + // The minimum time the connection can spend in PROBE_RTT mode. + probeRttTime = 200 * time.Millisecond + // If the bandwidth does not increase by the factor of |kStartupGrowthTarget| + // within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection + // will exit the STARTUP mode. + startupGrowthTarget = 1.25 + roundTripsWithoutGrowthBeforeExitingStartup = int64(3) + + // Flag. + defaultStartupFullLossCount = 8 + quicBbr2DefaultLossThreshold = 0.02 + maxBbrBurstPackets = 3 +) + +type bbrMode int + +const ( + // Startup phase of the connection. + bbrModeStartup = iota + // After achieving the highest possible bandwidth during the startup, lower + // the pacing rate in order to drain the queue. + bbrModeDrain + // Cruising mode. 
+ bbrModeProbeBw + // Temporarily slow down sending in order to empty the buffer and measure + // the real minimum RTT. + bbrModeProbeRtt +) + +// Indicates how the congestion control limits the amount of bytes in flight. +type bbrRecoveryState int + +const ( + // Do not limit. + bbrRecoveryStateNotInRecovery = iota + // Allow an extra outstanding byte for each byte acknowledged. + bbrRecoveryStateConservation + // Allow two extra outstanding bytes for each byte acknowledged (slow + // start). + bbrRecoveryStateGrowth +) + +type bbrSender struct { + rttStats congestion.RTTStatsProvider + clock Clock + pacer *common.Pacer + + mode bbrMode + + // Bandwidth sampler provides BBR with the bandwidth measurements at + // individual points. + sampler *bandwidthSampler + + // The number of the round trips that have occurred during the connection. + roundTripCount roundTripCount + + // The packet number of the most recently sent packet. + lastSentPacket congestion.PacketNumber + // Acknowledgement of any packet after |current_round_trip_end_| will cause + // the round trip counter to advance. + currentRoundTripEnd congestion.PacketNumber + + // Number of congestion events with some losses, in the current round. + numLossEventsInRound uint64 + + // Number of total bytes lost in the current round. + bytesLostInRound congestion.ByteCount + + // The filter that tracks the maximum bandwidth over the multiple recent + // round-trips. + maxBandwidth *WindowedFilter[Bandwidth, roundTripCount] + + // Minimum RTT estimate. Automatically expires within 10 seconds (and + // triggers PROBE_RTT mode) if no new value is sampled during that period. + minRtt time.Duration + // The time at which the current value of |min_rtt_| was assigned. + minRttTimestamp time.Time + + // The maximum allowed number of bytes in flight. + congestionWindow congestion.ByteCount + + // The initial value of the |congestion_window_|. + initialCongestionWindow congestion.ByteCount + + // The largest value the |congestion_window_| can achieve. + maxCongestionWindow congestion.ByteCount + + // The smallest value the |congestion_window_| can achieve. + minCongestionWindow congestion.ByteCount + + // The pacing gain applied during the STARTUP phase. + highGain float64 + + // The CWND gain applied during the STARTUP phase. + highCwndGain float64 + + // The pacing gain applied during the DRAIN phase. + drainGain float64 + + // The current pacing rate of the connection. + pacingRate Bandwidth + + // The gain currently applied to the pacing rate. + pacingGain float64 + // The gain currently applied to the congestion window. + congestionWindowGain float64 + + // The gain used for the congestion window during PROBE_BW. Latched from + // quic_bbr_cwnd_gain flag. + congestionWindowGainConstant float64 + // The number of RTTs to stay in STARTUP mode. Defaults to 3. + numStartupRtts int64 + + // Number of round-trips in PROBE_BW mode, used for determining the current + // pacing gain cycle. + cycleCurrentOffset int + // The time at which the last pacing gain cycle was started. + lastCycleStart time.Time + + // Indicates whether the connection has reached the full bandwidth mode. + isAtFullBandwidth bool + // Number of rounds during which there was no significant bandwidth increase. + roundsWithoutBandwidthGain int64 + // The bandwidth compared to which the increase is measured. + bandwidthAtLastRound Bandwidth + + // Set to true upon exiting quiescence. + exitingQuiescence bool + + // Time at which PROBE_RTT has to be exited. 
Setting it to zero indicates + // that the time is yet unknown as the number of packets in flight has not + // reached the required value. + exitProbeRttAt time.Time + // Indicates whether a round-trip has passed since PROBE_RTT became active. + probeRttRoundPassed bool + + // Indicates whether the most recent bandwidth sample was marked as + // app-limited. + lastSampleIsAppLimited bool + // Indicates whether any non app-limited samples have been recorded. + hasNoAppLimitedSample bool + + // Current state of recovery. + recoveryState bbrRecoveryState + // Receiving acknowledgement of a packet after |end_recovery_at_| will cause + // BBR to exit the recovery mode. A value above zero indicates at least one + // loss has been detected, so it must not be set back to zero. + endRecoveryAt congestion.PacketNumber + // A window used to limit the number of bytes in flight during loss recovery. + recoveryWindow congestion.ByteCount + // If true, consider all samples in recovery app-limited. + isAppLimitedRecovery bool // not used + + // When true, pace at 1.5x and disable packet conservation in STARTUP. + slowerStartup bool // not used + // When true, disables packet conservation in STARTUP. + rateBasedStartup bool // not used + + // When true, add the most recent ack aggregation measurement during STARTUP. + enableAckAggregationDuringStartup bool + // When true, expire the windowed ack aggregation values in STARTUP when + // bandwidth increases more than 25%. + expireAckAggregationInStartup bool + + // If true, will not exit low gain mode until bytes_in_flight drops below BDP + // or it's time for high gain mode. + drainToTarget bool + + // If true, slow down pacing rate in STARTUP when overshooting is detected. + detectOvershooting bool + // Bytes lost while detect_overshooting_ is true. + bytesLostWhileDetectingOvershooting congestion.ByteCount + // Slow down pacing rate if + // bytes_lost_while_detecting_overshooting_ * + // bytes_lost_multiplier_while_detecting_overshooting_ > IW. + bytesLostMultiplierWhileDetectingOvershooting uint8 + // When overshooting is detected, do not drop pacing_rate_ below this value / + // min_rtt. + cwndToCalculateMinPacingRate congestion.ByteCount + + // Max congestion window when adjusting network parameters. + maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used + + // Params. + maxDatagramSize congestion.ByteCount + // Recorded on packet sent. 
equivalent |unacked_packets_->bytes_in_flight()| + bytesInFlight congestion.ByteCount +} + +var _ congestion.CongestionControl = &bbrSender{} + +func NewBbrSender( + clock Clock, + initialMaxDatagramSize congestion.ByteCount, +) *bbrSender { + return newBbrSender( + clock, + initialMaxDatagramSize, + initialCongestionWindowPackets*initialMaxDatagramSize, + congestion.MaxCongestionWindowPackets*initialMaxDatagramSize, + ) +} + +func newBbrSender( + clock Clock, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow congestion.ByteCount, +) *bbrSender { + b := &bbrSender{ + clock: clock, + mode: bbrModeStartup, + sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)), + lastSentPacket: invalidPacketNumber, + currentRoundTripEnd: invalidPacketNumber, + maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]), + congestionWindow: initialCongestionWindow, + initialCongestionWindow: initialCongestionWindow, + maxCongestionWindow: initialMaxCongestionWindow, + minCongestionWindow: defaultMinimumCongestionWindow, + highGain: defaultHighGain, + highCwndGain: defaultHighGain, + drainGain: 1.0 / defaultHighGain, + pacingGain: 1.0, + congestionWindowGain: 1.0, + congestionWindowGainConstant: 2.0, + numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, + recoveryState: bbrRecoveryStateNotInRecovery, + endRecoveryAt: invalidPacketNumber, + recoveryWindow: initialMaxCongestionWindow, + bytesLostMultiplierWhileDetectingOvershooting: 2, + cwndToCalculateMinPacingRate: initialCongestionWindow, + maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, + maxDatagramSize: initialMaxDatagramSize, + } + b.pacer = common.NewPacer(b.bandwidthForPacer) + + /* + if b.tracer != nil { + b.lastState = logging.CongestionStateStartup + b.tracer.UpdatedCongestionState(logging.CongestionStateStartup) + } + */ + + b.enterStartupMode(b.clock.Now()) + b.setHighCwndGain(derivedHighCWNDGain) + + return b +} + +func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { + b.rttStats = provider +} + +// TimeUntilSend implements the SendAlgorithm interface. +func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return b.pacer.TimeUntilSend() +} + +// HasPacingBudget implements the SendAlgorithm interface. +func (b *bbrSender) HasPacingBudget(now time.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +// OnPacketSent implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketSent( + sentTime time.Time, + bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) + + b.lastSentPacket = packetNumber + b.bytesInFlight = bytesInFlight + + if bytesInFlight == 0 { + b.exitingQuiescence = true + } + + b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable) +} + +// CanSend implements the SendAlgorithm interface. +func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +// MaybeExitSlowStart implements the SendAlgorithm interface. +func (b *bbrSender) MaybeExitSlowStart() { + // Do nothing +} + +// OnPacketAcked implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime time.Time) { + // Do nothing. 
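+ // Per-packet ack/loss callbacks are intentionally left empty: all accounting
+ // for this sender happens in OnCongestionEventEx, which receives the acked
+ // and lost packets of each event in batches.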
+} + +// OnPacketLost implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +// OnRetransmissionTimeout implements the SendAlgorithm interface. +func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + // Do nothing. +} + +// SetMaxDatagramSize implements the SendAlgorithm interface. +func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) { + if s < b.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s)) + } + cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow + b.maxDatagramSize = s + if cwndIsMinCwnd { + b.congestionWindow = b.minCongestionWindow + } + b.pacer.SetMaxDatagramSize(s) +} + +// InSlowStart implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) InSlowStart() bool { + return b.mode == bbrModeStartup +} + +// InRecovery implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) InRecovery() bool { + return b.recoveryState != bbrRecoveryStateNotInRecovery +} + +// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) GetCongestionWindow() congestion.ByteCount { + if b.mode == bbrModeProbeRtt { + return b.probeRttCongestionWindow() + } + + if b.InRecovery() { + return min(b.congestionWindow, b.recoveryWindow) + } + + return b.congestionWindow +} + +func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + totalBytesAckedBefore := b.sampler.TotalBytesAcked() + totalBytesLostBefore := b.sampler.TotalBytesLost() + + var isRoundStart, minRttExpired bool + var excessAcked, bytesLost congestion.ByteCount + + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + var lastPacketSendState sendTimeState + + b.maybeApplimited(priorInFlight) + + // Update bytesInFlight + b.bytesInFlight = priorInFlight + for _, p := range ackedPackets { + b.bytesInFlight -= p.BytesAcked + } + for _, p := range lostPackets { + b.bytesInFlight -= p.BytesLost + } + + if len(ackedPackets) != 0 { + lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber + isRoundStart = b.updateRoundTripCounter(lastAckedPacket) + b.updateRecoveryState(lastAckedPacket, len(lostPackets) != 0, isRoundStart) + } + + sample := b.sampler.OnCongestionEvent(eventTime, + ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount) + if sample.lastPacketSendState.isValid { + b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited + b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited + } + // Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all + // packets in |acked_packets| did not generate valid samples. (e.g. ack of + // ack-only packets). In both cases, sampler_.total_bytes_acked() will not + // change. 
+ if totalBytesAckedBefore != b.sampler.TotalBytesAcked() { + if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() { + b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount) + } + } + + if sample.sampleRtt != infRTT { + minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt) + } + bytesLost = b.sampler.TotalBytesLost() - totalBytesLostBefore + + excessAcked = sample.extraAcked + lastPacketSendState = sample.lastPacketSendState + + if len(lostPackets) != 0 { + b.numLossEventsInRound++ + b.bytesLostInRound += bytesLost + } + + // Handle logic specific to PROBE_BW mode. + if b.mode == bbrModeProbeBw { + b.updateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) != 0) + } + + // Handle logic specific to STARTUP and DRAIN modes. + if isRoundStart && !b.isAtFullBandwidth { + b.checkIfFullBandwidthReached(&lastPacketSendState) + } + + b.maybeExitStartupOrDrain(eventTime) + + // Handle logic specific to PROBE_RTT. + b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) + + // Calculate number of packets acked and lost. + bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore + + // After the model is updated, recalculate the pacing rate and congestion + // window. + b.calculatePacingRate(bytesLost) + b.calculateCongestionWindow(bytesAcked, excessAcked) + b.calculateRecoveryWindow(bytesAcked, bytesLost) + + // Cleanup internal state. + // This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler. + // The "least unacked" should actually be FirstOutstanding, but since we are not passing + // that through OnCongestionEventEx, we will only do an estimate using acked/lost packets + // for now. Because of fast retransmission, they should differ by no more than 2 packets. + // (this is controlled by packetThreshold in quic-go's sentPacketHandler) + var leastUnacked congestion.PacketNumber + if len(ackedPackets) != 0 { + leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2 + } else { + leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1 + } + b.sampler.RemoveObsoletePackets(leastUnacked) + + if isRoundStart { + b.numLossEventsInRound = 0 + b.bytesLostInRound = 0 + } +} + +func (b *bbrSender) PacingRate() Bandwidth { + if b.pacingRate == 0 { + return Bandwidth(b.highGain * float64( + BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()))) + } + + return b.pacingRate +} + +func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool { + return b.hasNonAppLimitedSample() +} + +func (b *bbrSender) hasNonAppLimitedSample() bool { + return b.hasNoAppLimitedSample +} + +// Sets the pacing gain used in STARTUP. Must be greater than 1. +func (b *bbrSender) setHighGain(highGain float64) { + b.highGain = highGain + if b.mode == bbrModeStartup { + b.pacingGain = highGain + } +} + +// Sets the CWND gain used in STARTUP. Must be greater than 1. +func (b *bbrSender) setHighCwndGain(highCwndGain float64) { + b.highCwndGain = highCwndGain + if b.mode == bbrModeStartup { + b.congestionWindowGain = highCwndGain + } +} + +// Sets the gain used in DRAIN. Must be less than 1. +func (b *bbrSender) setDrainGain(drainGain float64) { + b.drainGain = drainGain +} + +// What's the current estimated bandwidth in bytes per second. 
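+// The Bandwidth value returned here counts bits per second; bandwidthForPacer
+// below scales it by the congestion window gain, converts it to bytes per
+// second for the pacer, and floors the result at minBps (65536 bytes/s) so the
+// pacer never sees a zero rate.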
+func (b *bbrSender) bandwidthEstimate() Bandwidth { + return b.maxBandwidth.GetBest() +} + +func (b *bbrSender) bandwidthForPacer() congestion.ByteCount { + bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond)) + if bps < minBps { + // We need to make sure that the bandwidth value for pacer is never zero, + // otherwise it will go into an edge case where HasPacingBudget = false + // but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck. + return minBps + } + return bps +} + +// Returns the current estimate of the RTT of the connection. Outside of the +// edge cases, this is minimum RTT. +func (b *bbrSender) getMinRtt() time.Duration { + if b.minRtt != 0 { + return b.minRtt + } + // min_rtt could be available if the handshake packet gets neutered then + // gets acknowledged. This could only happen for QUIC crypto where we do not + // drop keys. + minRtt := b.rttStats.MinRTT() + if minRtt == 0 { + return 100 * time.Millisecond + } else { + return minRtt + } +} + +// Computes the target congestion window using the specified gain. +func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount { + bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate()) + congestionWindow := congestion.ByteCount(gain * float64(bdp)) + + // BDP estimate will be zero if no bandwidth samples are available yet. + if congestionWindow == 0 { + congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow)) + } + + return max(congestionWindow, b.minCongestionWindow) +} + +// The target congestion window during PROBE_RTT. +func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount { + return b.minCongestionWindow +} + +func (b *bbrSender) maybeUpdateMinRtt(now time.Time, sampleMinRtt time.Duration) bool { + // Do not expire min_rtt if none was ever available. + minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry)) + if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 { + b.minRtt = sampleMinRtt + b.minRttTimestamp = now + } + + return minRttExpired +} + +// Enters the STARTUP mode. +func (b *bbrSender) enterStartupMode(now time.Time) { + b.mode = bbrModeStartup + // b.maybeTraceStateChange(logging.CongestionStateStartup) + b.pacingGain = b.highGain + b.congestionWindowGain = b.highCwndGain +} + +// Enters the PROBE_BW mode. +func (b *bbrSender) enterProbeBandwidthMode(now time.Time) { + b.mode = bbrModeProbeBw + // b.maybeTraceStateChange(logging.CongestionStateProbeBw) + b.congestionWindowGain = b.congestionWindowGainConstant + + // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is + // excluded because in that case increased gain and decreased gain would not + // follow each other. + b.cycleCurrentOffset = int(rand.Int31n(congestion.PacketsPerConnectionID)) % (gainCycleLength - 1) + if b.cycleCurrentOffset >= 1 { + b.cycleCurrentOffset += 1 + } + + b.lastCycleStart = now + b.pacingGain = pacingGain[b.cycleCurrentOffset] +} + +// Updates the round-trip counter if a round-trip has passed. Returns true if +// the counter has been advanced. +func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool { + if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd { + b.roundTripCount++ + b.currentRoundTripEnd = b.lastSentPacket + return true + } + return false +} + +// Updates the current gain used in PROBE_BW mode. 
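+// The cycle walks the pacingGain table: one round trip of probing above the
+// estimated bandwidth (gain 1.25), one of draining the queue that probing may
+// have built (gain 0.75), then six neutral round trips (gain 1.0). A phase
+// normally advances about once per min RTT, so with a 50 ms min RTT a full
+// cycle takes roughly 400 ms.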
+func (b *bbrSender) updateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) { + // In most cases, the cycle is advanced after an RTT passes. + shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt())) + // If the pacing gain is above 1.0, the connection is trying to probe the + // bandwidth by increasing the number of bytes in flight to at least + // pacing_gain * BDP. Make sure that it actually reaches the target, as long + // as there are no losses suggesting that the buffers are not able to hold + // that much. + if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) { + shouldAdvanceGainCycling = false + } + + // If pacing gain is below 1.0, the connection is trying to drain the extra + // queue which could have been incurred by probing prior to it. If the number + // of bytes in flight falls down to the estimated BDP value earlier, conclude + // that the queue has been successfully drained and exit this cycle early. + if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) { + shouldAdvanceGainCycling = true + } + + if shouldAdvanceGainCycling { + b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength + b.lastCycleStart = now + // Stay in low gain mode until the target BDP is hit. + // Low gain mode will be exited immediately when the target BDP is achieved. + if b.drainToTarget && b.pacingGain < 1 && + pacingGain[b.cycleCurrentOffset] == 1 && + b.bytesInFlight > b.getTargetCongestionWindow(1) { + return + } + b.pacingGain = pacingGain[b.cycleCurrentOffset] + } +} + +// Tracks for how many round-trips the bandwidth has not increased +// significantly. +func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) { + if b.lastSampleIsAppLimited { + return + } + + target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget) + if b.bandwidthEstimate() >= target { + b.bandwidthAtLastRound = b.bandwidthEstimate() + b.roundsWithoutBandwidthGain = 0 + if b.expireAckAggregationInStartup { + // Expire old excess delivery measurements now that bandwidth increased. + b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount) + } + return + } + + b.roundsWithoutBandwidthGain++ + if b.roundsWithoutBandwidthGain >= b.numStartupRtts || + b.shouldExitStartupDueToLoss(lastPacketSendState) { + b.isAtFullBandwidth = true + } +} + +func (b *bbrSender) maybeApplimited(bytesInFlight congestion.ByteCount) { + congestionWindow := b.GetCongestionWindow() + if bytesInFlight >= congestionWindow { + return + } + availableBytes := congestionWindow - bytesInFlight + drainLimited := b.mode == bbrModeDrain && bytesInFlight > congestionWindow/2 + if !drainLimited || availableBytes > maxBbrBurstPackets*b.maxDatagramSize { + b.sampler.OnAppLimited() + } +} + +// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if +// appropriate. +func (b *bbrSender) maybeExitStartupOrDrain(now time.Time) { + if b.mode == bbrModeStartup && b.isAtFullBandwidth { + b.mode = bbrModeDrain + // b.maybeTraceStateChange(logging.CongestionStateDrain) + b.pacingGain = b.drainGain + b.congestionWindowGain = b.highCwndGain + } + if b.mode == bbrModeDrain && b.bytesInFlight <= b.getTargetCongestionWindow(1) { + b.enterProbeBandwidthMode(now) + } +} + +// Decides whether to enter or exit PROBE_RTT. 
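+// PROBE_RTT is entered once the min_rtt estimate has gone minRttExpiry (10s)
+// without being refreshed; while in it the congestion window is pinned to
+// minCongestionWindow, and the mode lasts at least probeRttTime (200ms) plus
+// one round trip before STARTUP or PROBE_BW resumes.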
+func (b *bbrSender) maybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) { + if minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt { + b.mode = bbrModeProbeRtt + // b.maybeTraceStateChange(logging.CongestionStateProbRtt) + b.pacingGain = 1.0 + // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| + // is at the target small value. + b.exitProbeRttAt = time.Time{} + } + + if b.mode == bbrModeProbeRtt { + b.sampler.OnAppLimited() + // b.maybeTraceStateChange(logging.CongestionStateApplicationLimited) + + if b.exitProbeRttAt.IsZero() { + // If the window has reached the appropriate size, schedule exiting + // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but + // we allow an extra packet since QUIC checks CWND before sending a + // packet. + if b.bytesInFlight < b.probeRttCongestionWindow()+congestion.MaxPacketBufferSize { + b.exitProbeRttAt = now.Add(probeRttTime) + b.probeRttRoundPassed = false + } + } else { + if isRoundStart { + b.probeRttRoundPassed = true + } + if now.Sub(b.exitProbeRttAt) >= 0 && b.probeRttRoundPassed { + b.minRttTimestamp = now + if !b.isAtFullBandwidth { + b.enterStartupMode(now) + } else { + b.enterProbeBandwidthMode(now) + } + } + } + } + + b.exitingQuiescence = false +} + +// Determines whether BBR needs to enter, exit or advance state of the +// recovery. +func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) { + // Disable recovery in startup, if loss-based exit is enabled. + if !b.isAtFullBandwidth { + return + } + + // Exit recovery when there are no losses for a round. + if hasLosses { + b.endRecoveryAt = b.lastSentPacket + } + + switch b.recoveryState { + case bbrRecoveryStateNotInRecovery: + if hasLosses { + b.recoveryState = bbrRecoveryStateConservation + // This will cause the |recovery_window_| to be set to the correct + // value in CalculateRecoveryWindow(). + b.recoveryWindow = 0 + // Since the conservation phase is meant to be lasting for a whole + // round, extend the current round as if it were started right now. + b.currentRoundTripEnd = b.lastSentPacket + } + case bbrRecoveryStateConservation: + if isRoundStart { + b.recoveryState = bbrRecoveryStateGrowth + } + fallthrough + case bbrRecoveryStateGrowth: + // Exit recovery if appropriate. + if !hasLosses && lastAckedPacket > b.endRecoveryAt { + b.recoveryState = bbrRecoveryStateNotInRecovery + } + } +} + +// Determines the appropriate pacing rate for the connection. +func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) { + if b.bandwidthEstimate() == 0 { + return + } + + targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate())) + if b.isAtFullBandwidth { + b.pacingRate = targetRate + return + } + + // Pace at the rate of initial_window / RTT as soon as RTT measurements are + // available. + if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 { + b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT()) + return + } + + if b.detectOvershooting { + b.bytesLostWhileDetectingOvershooting += bytesLost + // Check for overshooting with network parameters adjusted when pacing rate + // > target_rate and loss has been detected. 
+ if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 { + if b.hasNoAppLimitedSample || + b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow { + // We are fairly sure overshoot happens if 1) there is at least one + // non app-limited bw sample or 2) half of IW gets lost. Slow pacing + // rate. + b.pacingRate = max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT())) + b.bytesLostWhileDetectingOvershooting = 0 + b.detectOvershooting = false + } + } + } + + // Do not decrease the pacing rate during startup. + b.pacingRate = max(b.pacingRate, targetRate) +} + +// Determines the appropriate congestion window for the connection. +func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) { + if b.mode == bbrModeProbeRtt { + return + } + + targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain) + if b.isAtFullBandwidth { + // Add the max recently measured ack aggregation to CWND. + targetWindow += b.sampler.MaxAckHeight() + } else if b.enableAckAggregationDuringStartup { + // Add the most recent excess acked. Because CWND never decreases in + // STARTUP, this will automatically create a very localized max filter. + targetWindow += excessAcked + } + + // Instead of immediately setting the target CWND as the new one, BBR grows + // the CWND towards |target_window| by only increasing it |bytes_acked| at a + // time. + if b.isAtFullBandwidth { + b.congestionWindow = min(targetWindow, b.congestionWindow+bytesAcked) + } else if b.congestionWindow < targetWindow || + b.sampler.TotalBytesAcked() < b.initialCongestionWindow { + // If the connection is not yet out of startup phase, do not decrease the + // window. + b.congestionWindow += bytesAcked + } + + // Enforce the limits on the congestion window. + b.congestionWindow = max(b.congestionWindow, b.minCongestionWindow) + b.congestionWindow = min(b.congestionWindow, b.maxCongestionWindow) +} + +// Determines the appropriate window that constrains the in-flight during recovery. +func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) { + if b.recoveryState == bbrRecoveryStateNotInRecovery { + return + } + + // Set up the initial recovery window. + if b.recoveryWindow == 0 { + b.recoveryWindow = b.bytesInFlight + bytesAcked + b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow) + return + } + + // Remove losses from the recovery window, while accounting for a potential + // integer underflow. + if b.recoveryWindow >= bytesLost { + b.recoveryWindow = b.recoveryWindow - bytesLost + } else { + b.recoveryWindow = b.maxDatagramSize + } + + // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, + // release additional |bytes_acked| to achieve a slow-start-like behavior. + if b.recoveryState == bbrRecoveryStateGrowth { + b.recoveryWindow += bytesAcked + } + + // Always allow sending at least |bytes_acked| in response. + b.recoveryWindow = max(b.recoveryWindow, b.bytesInFlight+bytesAcked) + b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow) +} + +// Return whether we should exit STARTUP due to excessive loss. 
+func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool { + if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid { + return false + } + + inflightAtSend := lastPacketSendState.bytesInFlight + + if inflightAtSend > 0 && b.bytesLostInRound > 0 { + if b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold) { + return true + } + return false + } + return false +} + +func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount { + return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second) +} + +func GetInitialPacketSize(addr net.Addr) congestion.ByteCount { + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if udpAddr.IP.To4() != nil { + return congestion.InitialPacketSizeIPv4 + } else { + return congestion.InitialPacketSizeIPv6 + } + } else { + return congestion.MinInitialPacketSize + } +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/clock.go b/transport/hysteria2/core/internal/congestion/bbr/clock.go new file mode 100644 index 00000000..a66344fb --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/clock.go @@ -0,0 +1,18 @@ +package bbr + +import "time" + +// A Clock returns the current time +type Clock interface { + Now() time.Time +} + +// DefaultClock implements the Clock interface using the Go stdlib clock. +type DefaultClock struct{} + +var _ Clock = DefaultClock{} + +// Now gets the current time +func (DefaultClock) Now() time.Time { + return time.Now() +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/packet_number_indexed_queue.go b/transport/hysteria2/core/internal/congestion/bbr/packet_number_indexed_queue.go new file mode 100644 index 00000000..55e9c2bd --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/packet_number_indexed_queue.go @@ -0,0 +1,199 @@ +package bbr + +import ( + "github.com/metacubex/quic-go/congestion" +) + +// packetNumberIndexedQueue is a queue of mostly continuous numbered entries +// which supports the following operations: +// - adding elements to the end of the queue, or at some point past the end +// - removing elements in any order +// - retrieving elements +// If all elements are inserted in order, all of the operations above are +// amortized O(1) time. +// +// Internally, the data structure is a deque where each element is marked as +// present or not. The deque starts at the lowest present index. Whenever an +// element is removed, it's marked as not present, and the front of the deque is +// cleared of elements that are not present. +// +// The tail of the queue is not cleared due to the assumption of entries being +// inserted in order, though removing all elements of the queue will return it +// to its initial state. +// +// Note that this data structure is inherently hazardous, since an addition of +// just two entries will cause it to consume all of the memory available. +// Because of that, it is not a general-purpose container and should not be used +// as one. 
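The queue described above is easiest to follow with a concrete sequence of operations. The sketch below is an editor's illustration rather than part of the patch: it assumes it sits in the same bbr package as the definitions that follow, and the int payload and function name are placeholders. It also shows where the hazard mentioned above comes from: Emplace fills every missing slot between the last and the new packet number, so two entries with widely separated numbers allocate the whole range in between.

// Editor's sketch (not part of the patch), assuming the bbr package definitions below.
func examplePacketNumberIndexedQueue() {
	q := newPacketNumberIndexedQueue[int](16)
	one, three := 1, 3
	q.Emplace(1, &one)   // the queue now starts at packet 1
	q.Emplace(3, &three) // packet 2 is filled in as a "not present" placeholder
	q.Remove(1, nil)     // the front is cleaned up; FirstPacket() advances to 3
	_ = q.GetEntry(3)    // non-nil: the entry is still present
	_ = q.GetEntry(2)    // nil: that slot was never populated
}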
+
+type entryWrapper[T any] struct {
+	present bool
+	entry   T
+}
+
+type packetNumberIndexedQueue[T any] struct {
+	entries                RingBuffer[entryWrapper[T]]
+	numberOfPresentEntries int
+	firstPacket            congestion.PacketNumber
+}
+
+func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] {
+	q := &packetNumberIndexedQueue[T]{
+		firstPacket: invalidPacketNumber,
+	}
+
+	q.entries.Init(size)
+
+	return q
+}
+
+// Emplace inserts data associated with |packet_number| into (or past) the end
+// of the queue, filling up the missing intermediate entries as necessary.
+// Returns true if the element has been inserted successfully, false if it was
+// already in the queue or inserted out of order.
+func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool {
+	if packetNumber == invalidPacketNumber || entry == nil {
+		return false
+	}
+
+	if p.IsEmpty() {
+		p.entries.PushBack(entryWrapper[T]{
+			present: true,
+			entry:   *entry,
+		})
+		p.numberOfPresentEntries = 1
+		p.firstPacket = packetNumber
+		return true
+	}
+
+	// Do not allow insertion out-of-order.
+	if packetNumber <= p.LastPacket() {
+		return false
+	}
+
+	// Handle potentially missing elements.
+	offset := int(packetNumber - p.FirstPacket())
+	if gap := offset - p.entries.Len(); gap > 0 {
+		for i := 0; i < gap; i++ {
+			p.entries.PushBack(entryWrapper[T]{})
+		}
+	}
+
+	p.entries.PushBack(entryWrapper[T]{
+		present: true,
+		entry:   *entry,
+	})
+	p.numberOfPresentEntries++
+	return true
+}
+
+// GetEntry retrieves the entry associated with the packet number. Returns a
+// pointer to the entry on success, or nil if the entry does not exist.
+func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T {
+	ew := p.getEntryWraper(packetNumber)
+	if ew == nil {
+		return nil
+	}
+
+	return &ew.entry
+}
+
+// Remove removes the entry associated with the packet number. If the entry is
+// present in the queue and f is not nil, f(entry) is called before removal.
+func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool {
+	ew := p.getEntryWraper(packetNumber)
+	if ew == nil {
+		return false
+	}
+	if f != nil {
+		f(ew.entry)
+	}
+	ew.present = false
+	p.numberOfPresentEntries--
+
+	if packetNumber == p.FirstPacket() {
+		p.clearup()
+	}
+
+	return true
+}
+
+// RemoveUpTo removes all entries up to, but not including, |packet_number|.
+// Unused slots in the front are also removed, which means that when the
+// function returns, |first_packet()| can be larger than |packet_number|.
+func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) {
+	for !p.entries.Empty() &&
+		p.firstPacket != invalidPacketNumber &&
+		p.firstPacket < packetNumber {
+		if p.entries.Front().present {
+			p.numberOfPresentEntries--
+		}
+		p.entries.PopFront()
+		p.firstPacket++
+	}
+	p.clearup()
+
+	return
+}
+
+// IsEmpty returns whether the queue is empty.
+func (p *packetNumberIndexedQueue[T]) IsEmpty() bool {
+	return p.numberOfPresentEntries == 0
+}
+
+// NumberOfPresentEntries returns the number of entries in the queue.
+func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int {
+	return p.numberOfPresentEntries
+}
+
+// EntrySlotsUsed returns the number of entry slots allocated in the underlying
+// deque. This is proportional to the memory usage of the queue.
+func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int {
+	return p.entries.Len()
+}
+
+// FirstPacket returns the packet number of the first entry in the queue.
+func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) { + return p.firstPacket +} + +// LastPacket returns packet number of the last entry ever inserted in the queue. Note that the +// entry in question may have already been removed. Zero if the queue is +// empty. +func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) { + if p.IsEmpty() { + return invalidPacketNumber + } + + return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1) +} + +func (p *packetNumberIndexedQueue[T]) clearup() { + for !p.entries.Empty() && !p.entries.Front().present { + p.entries.PopFront() + p.firstPacket++ + } + if p.entries.Empty() { + p.firstPacket = invalidPacketNumber + } +} + +func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] { + if packetNumber == invalidPacketNumber || + p.IsEmpty() || + packetNumber < p.firstPacket { + return nil + } + + offset := int(packetNumber - p.firstPacket) + if offset >= p.entries.Len() { + return nil + } + + ew := p.entries.Offset(offset) + if ew == nil || !ew.present { + return nil + } + + return ew +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/ringbuffer.go b/transport/hysteria2/core/internal/congestion/bbr/ringbuffer.go new file mode 100644 index 00000000..ed92d4ce --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/ringbuffer.go @@ -0,0 +1,118 @@ +package bbr + +// A RingBuffer is a ring buffer. +// It acts as a heap that doesn't cause any allocations. +type RingBuffer[T any] struct { + ring []T + headPos, tailPos int + full bool +} + +// Init preallocs a buffer with a certain size. +func (r *RingBuffer[T]) Init(size int) { + r.ring = make([]T, size) +} + +// Len returns the number of elements in the ring buffer. +func (r *RingBuffer[T]) Len() int { + if r.full { + return len(r.ring) + } + if r.tailPos >= r.headPos { + return r.tailPos - r.headPos + } + return r.tailPos - r.headPos + len(r.ring) +} + +// Empty says if the ring buffer is empty. +func (r *RingBuffer[T]) Empty() bool { + return !r.full && r.headPos == r.tailPos +} + +// PushBack adds a new element. +// If the ring buffer is full, its capacity is increased first. +func (r *RingBuffer[T]) PushBack(t T) { + if r.full || len(r.ring) == 0 { + r.grow() + } + r.ring[r.tailPos] = t + r.tailPos++ + if r.tailPos == len(r.ring) { + r.tailPos = 0 + } + if r.tailPos == r.headPos { + r.full = true + } +} + +// PopFront returns the next element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) PopFront() T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") + } + r.full = false + t := r.ring[r.headPos] + r.ring[r.headPos] = *new(T) + r.headPos++ + if r.headPos == len(r.ring) { + r.headPos = 0 + } + return t +} + +// Offset returns the offset element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first +// and check if the index larger than buffer length. +func (r *RingBuffer[T]) Offset(index int) *T { + if r.Empty() || index >= r.Len() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index") + } + offset := (r.headPos + index) % len(r.ring) + return &r.ring[offset] +} + +// Front returns the front element. 
+// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) Front() *T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue") + } + return &r.ring[r.headPos] +} + +// Back returns the back element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) Back() *T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue") + } + return r.Offset(r.Len() - 1) +} + +// Grow the maximum size of the queue. +// This method assume the queue is full. +func (r *RingBuffer[T]) grow() { + oldRing := r.ring + newSize := len(oldRing) * 2 + if newSize == 0 { + newSize = 1 + } + r.ring = make([]T, newSize) + headLen := copy(r.ring, oldRing[r.headPos:]) + copy(r.ring[headLen:], oldRing[:r.headPos]) + r.headPos, r.tailPos, r.full = 0, len(oldRing), false +} + +// Clear removes all elements. +func (r *RingBuffer[T]) Clear() { + var zeroValue T + for i := range r.ring { + r.ring[i] = zeroValue + } + r.headPos, r.tailPos, r.full = 0, 0, false +} diff --git a/transport/hysteria2/core/internal/congestion/bbr/windowed_filter.go b/transport/hysteria2/core/internal/congestion/bbr/windowed_filter.go new file mode 100644 index 00000000..4773bce5 --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/bbr/windowed_filter.go @@ -0,0 +1,162 @@ +package bbr + +import ( + "golang.org/x/exp/constraints" +) + +// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum) +// estimate of a stream of samples over some fixed time interval. (E.g., +// the minimum RTT over the past five minutes.) The algorithm keeps track of +// the best, second best, and third best min (or max) estimates, maintaining an +// invariant that the measurement time of the n'th best >= n-1'th best. + +// The algorithm works as follows. On a reset, all three estimates are set to +// the same sample. The second best estimate is then recorded in the second +// quarter of the window, and a third best estimate is recorded in the second +// half of the window, bounding the worst case error when the true min is +// monotonically increasing (or true max is monotonically decreasing) over the +// window. +// +// A new best sample replaces all three estimates, since the new best is lower +// (or higher) than everything else in the window and it is the most recent. +// The window thus effectively gets reset on every new min. The same property +// holds true for second best and third best estimates. Specifically, when a +// sample arrives that is better than the second best but not better than the +// best, it replaces the second and third best estimates but not the best +// estimate. Similarly, a sample that is better than the third best estimate +// but not the other estimates replaces only the third best estimate. +// +// Finally, when the best expires, it is replaced by the second best, which in +// turn is replaced by the third best. The newest sample replaces the third +// best. + +type WindowedFilterValue interface { + any +} + +type WindowedFilterTime interface { + constraints.Integer | constraints.Float +} + +type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct { + // Time length of window. 
+	windowLength T
+	estimates    []entry[V, T]
+	comparator   func(V, V) int
+}
+
+type entry[V WindowedFilterValue, T WindowedFilterTime] struct {
+	sample V
+	time   T
+}
+
+// MaxFilter is a comparator for tracking a windowed maximum: it returns 1 if
+// the first value is greater than the second, -1 if it is smaller, and 0 if
+// they are equal.
+func MaxFilter[O constraints.Ordered](a, b O) int {
+	if a > b {
+		return 1
+	} else if a < b {
+		return -1
+	}
+	return 0
+}
+
+// MinFilter is a comparator for tracking a windowed minimum: it returns 1 if
+// the first value is less than the second, -1 if it is greater, and 0 if they
+// are equal.
+func MinFilter[O constraints.Ordered](a, b O) int {
+	if a < b {
+		return 1
+	} else if a > b {
+		return -1
+	}
+	return 0
+}
+
+func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] {
+	return &WindowedFilter[V, T]{
+		windowLength: windowLength,
+		estimates:    make([]entry[V, T], 3, 3),
+		comparator:   comparator,
+	}
+}
+
+// SetWindowLength changes the window length. It does not update any current samples.
+func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) {
+	f.windowLength = windowLength
+}
+
+func (f *WindowedFilter[V, T]) GetBest() V {
+	return f.estimates[0].sample
+}
+
+func (f *WindowedFilter[V, T]) GetSecondBest() V {
+	return f.estimates[1].sample
+}
+
+func (f *WindowedFilter[V, T]) GetThirdBest() V {
+	return f.estimates[2].sample
+}
+
+// Update updates the best estimates with |sample|, and expires and updates the
+// best estimates as necessary.
+func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) {
+	// Reset all estimates if they have not yet been initialized, if the new
+	// sample is a new best, or if the newest recorded estimate is too old.
+	if f.comparator(f.estimates[0].sample, *new(V)) == 0 ||
+		f.comparator(newSample, f.estimates[0].sample) >= 0 ||
+		newTime-f.estimates[2].time > f.windowLength {
+		f.Reset(newSample, newTime)
+		return
+	}
+
+	if f.comparator(newSample, f.estimates[1].sample) >= 0 {
+		f.estimates[1] = entry[V, T]{newSample, newTime}
+		f.estimates[2] = f.estimates[1]
+	} else if f.comparator(newSample, f.estimates[2].sample) >= 0 {
+		f.estimates[2] = entry[V, T]{newSample, newTime}
+	}
+
+	// Expire and update estimates as necessary.
+	if newTime-f.estimates[0].time > f.windowLength {
+		// The best estimate hasn't been updated for an entire window, so promote
+		// second and third best estimates.
+		f.estimates[0] = f.estimates[1]
+		f.estimates[1] = f.estimates[2]
+		f.estimates[2] = entry[V, T]{newSample, newTime}
+		// Need to iterate one more time. Check if the new best estimate is
+		// outside the window as well, since it may also have been recorded a
+		// long time ago. Don't need to iterate once more since we cover that
+		// case at the beginning of the method.
+		if newTime-f.estimates[0].time > f.windowLength {
+			f.estimates[0] = f.estimates[1]
+			f.estimates[1] = f.estimates[2]
+		}
+		return
+	}
+	if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 &&
+		newTime-f.estimates[1].time > f.windowLength/4 {
+		// A quarter of the window has passed without a better sample, so the
+		// second-best estimate is taken from the second quarter of the window.
+		f.estimates[1] = entry[V, T]{newSample, newTime}
+		f.estimates[2] = f.estimates[1]
+		return
+	}
+
+	if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 &&
+		newTime-f.estimates[2].time > f.windowLength/2 {
+		// We've passed a half of the window without a better estimate, so take
+		// a third-best estimate from the second half of the window.
+ f.estimates[2] = entry[V, T]{newSample, newTime} + } +} + +// Resets all estimates to new sample. +func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) { + f.estimates[2] = entry[V, T]{newSample, newTime} + f.estimates[1] = f.estimates[2] + f.estimates[0] = f.estimates[1] +} + +func (f *WindowedFilter[V, T]) Clear() { + f.estimates = make([]entry[V, T], 3, 3) +} diff --git a/transport/hysteria2/core/internal/congestion/brutal/brutal.go b/transport/hysteria2/core/internal/congestion/brutal/brutal.go new file mode 100644 index 00000000..e5792027 --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/brutal/brutal.go @@ -0,0 +1,181 @@ +package brutal + +import ( + "fmt" + "os" + "strconv" + "time" + + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/common" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample + minSampleCount = 50 + minAckRate = 0.8 + congestionWindowMultiplier = 2 + + debugEnv = "HYSTERIA_BRUTAL_DEBUG" + debugPrintInterval = 2 +) + +var _ congestion.CongestionControl = &BrutalSender{} + +type BrutalSender struct { + rttStats congestion.RTTStatsProvider + bps congestion.ByteCount + maxDatagramSize congestion.ByteCount + pacer *common.Pacer + + pktInfoSlots [pktInfoSlotCount]pktInfo + ackRate float64 + + debug bool + lastAckPrintTimestamp int64 +} + +type pktInfo struct { + Timestamp int64 + AckCount uint64 + LossCount uint64 +} + +func NewBrutalSender(bps uint64) *BrutalSender { + debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) + bs := &BrutalSender{ + bps: congestion.ByteCount(bps), + maxDatagramSize: congestion.InitialPacketSizeIPv4, + ackRate: 1, + debug: debug, + } + bs.pacer = common.NewPacer(func() congestion.ByteCount { + return congestion.ByteCount(float64(bs.bps) / bs.ackRate) + }) + return bs +} + +func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { + b.rttStats = rttStats +} + +func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return b.pacer.TimeUntilSend() +} + +func (b *BrutalSender) HasPacingBudget(now time.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { + rtt := b.rttStats.SmoothedRTT() + if rtt <= 0 { + return 10240 + } + return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) +} + +func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) +} + +func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, eventTime time.Time, +) { + // Stub +} + +func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, +) { + // Stub +} + +func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + currentTimestamp := eventTime.Unix() + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + 
b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets)) + b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets)) + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets)) + b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets)) + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { + b.maxDatagramSize = size + b.pacer.SetMaxDatagramSize(size) + if b.debug { + b.debugPrint("SetMaxDatagramSize: %d", size) + } +} + +func (b *BrutalSender) updateAckRate(currentTimestamp int64) { + minTimestamp := currentTimestamp - pktInfoSlotCount + var ackCount, lossCount uint64 + for _, info := range b.pktInfoSlots { + if info.Timestamp < minTimestamp { + continue + } + ackCount += info.AckCount + lossCount += info.LossCount + } + if ackCount+lossCount < minSampleCount { + b.ackRate = 1 + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)", + ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return + } + rate := float64(ackCount) / float64(ackCount+lossCount) + if rate < minAckRate { + b.ackRate = minAckRate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return + } + b.ackRate = rate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } +} + +func (b *BrutalSender) InSlowStart() bool { + return false +} + +func (b *BrutalSender) InRecovery() bool { + return false +} + +func (b *BrutalSender) MaybeExitSlowStart() {} + +func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} + +func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool { + return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval +} + +func (b *BrutalSender) debugPrint(format string, a ...any) { + fmt.Printf("[BrutalSender] [%s] %s\n", + time.Now().Format("15:04:05"), + fmt.Sprintf(format, a...)) +} diff --git a/transport/hysteria2/core/internal/congestion/common/pacer.go b/transport/hysteria2/core/internal/congestion/common/pacer.go new file mode 100644 index 00000000..e1ec9f6c --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/common/pacer.go @@ -0,0 +1,95 @@ +package common + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + maxBurstPackets = 10 +) + +// Pacer implements a token bucket pacing algorithm. 
+type Pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime time.Time + getBandwidth func() congestion.ByteCount // in bytes/s +} + +func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer { + p := &Pacer{ + budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4, + maxDatagramSize: congestion.InitialPacketSizeIPv4, + getBandwidth: getBandwidth, + } + return p +} + +func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *Pacer) Budget(now time.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = congestion.ByteCount(1<<62 - 1) + } + return minByteCount(p.maxBurstSize(), budget) +} + +func (p *Pacer) maxBurstSize() congestion.ByteCount { + return maxByteCount( + congestion.ByteCount((congestion.MinPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9, + maxBurstPackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *Pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(maxDuration( + congestion.MinPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/ + float64(p.getBandwidth())))*time.Nanosecond, + )) +} + +func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} + +func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return b + } + return a +} + +func minByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return a + } + return b +} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/transport/hysteria2/core/internal/congestion/utils.go b/transport/hysteria2/core/internal/congestion/utils.go new file mode 100644 index 00000000..724c5cb9 --- /dev/null +++ b/transport/hysteria2/core/internal/congestion/utils.go @@ -0,0 +1,18 @@ +package congestion + +import ( + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/bbr" + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/brutal" + "github.com/metacubex/quic-go" +) + +func UseBBR(conn quic.Connection) { + conn.SetCongestionControl(bbr.NewBbrSender( + bbr.DefaultClock{}, + bbr.GetInitialPacketSize(conn.RemoteAddr()), + )) +} + +func UseBrutal(conn quic.Connection, tx uint64) { + conn.SetCongestionControl(brutal.NewBrutalSender(tx)) +} diff --git a/transport/hysteria2/core/internal/frag/frag.go b/transport/hysteria2/core/internal/frag/frag.go new file mode 100644 index 00000000..7653f81f --- /dev/null +++ b/transport/hysteria2/core/internal/frag/frag.go @@ -0,0 +1,77 @@ +package frag + +import ( + "github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol" +) + +func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage { + if m.Size() <= maxSize { + return []protocol.UDPMessage{*m} + } + fullPayload := m.Data + maxPayloadSize := maxSize - m.HeaderSize() + off := 0 + fragID := uint8(0) + fragCount := 
uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up + frags := make([]protocol.UDPMessage, fragCount) + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > maxPayloadSize { + payloadSize = maxPayloadSize + } + frag := *m + frag.FragID = fragID + frag.FragCount = fragCount + frag.Data = fullPayload[off : off+payloadSize] + frags[fragID] = frag + off += payloadSize + fragID++ + } + return frags +} + +// Defragger handles the defragmentation of UDP messages. +// The current implementation can only handle one packet ID at a time. +// If another packet arrives before a packet has received all fragments +// in their entirety, any previous state is discarded. +type Defragger struct { + pktID uint16 + frags []*protocol.UDPMessage + count uint8 + size int // data size +} + +func (d *Defragger) Feed(m *protocol.UDPMessage) *protocol.UDPMessage { + if m.FragCount <= 1 { + return m + } + if m.FragID >= m.FragCount { + // wtf is this? + return nil + } + if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) { + // new message, clear previous state + d.pktID = m.PacketID + d.frags = make([]*protocol.UDPMessage, m.FragCount) + d.frags[m.FragID] = m + d.count = 1 + d.size = len(m.Data) + } else if d.frags[m.FragID] == nil { + d.frags[m.FragID] = m + d.count++ + d.size += len(m.Data) + if int(d.count) == len(d.frags) { + // all fragments received, assemble + data := make([]byte, d.size) + off := 0 + for _, frag := range d.frags { + off += copy(data[off:], frag.Data) + } + m.Data = data + m.FragID = 0 + m.FragCount = 1 + return m + } + } + return nil +} diff --git a/transport/hysteria2/core/internal/pmtud/avail.go b/transport/hysteria2/core/internal/pmtud/avail.go new file mode 100644 index 00000000..cd7afd01 --- /dev/null +++ b/transport/hysteria2/core/internal/pmtud/avail.go @@ -0,0 +1,7 @@ +//go:build linux || windows || darwin + +package pmtud + +const ( + DisablePathMTUDiscovery = false +) diff --git a/transport/hysteria2/core/internal/pmtud/unavail.go b/transport/hysteria2/core/internal/pmtud/unavail.go new file mode 100644 index 00000000..917b973a --- /dev/null +++ b/transport/hysteria2/core/internal/pmtud/unavail.go @@ -0,0 +1,13 @@ +//go:build !linux && !windows && !darwin + +package pmtud + +// quic-go's MTU detection is enabled by default on all platforms. +// However, it only actually sets the DF bit on 3 supported platforms (Windows, macOS, Linux). +// As a result, on other platforms, probe packets that should never be fragmented will still +// be fragmented and transmitted. So we're only enabling it for platforms where we've verified +// its functionality for now. + +const ( + DisablePathMTUDiscovery = true +) diff --git a/transport/hysteria2/core/internal/protocol/http.go b/transport/hysteria2/core/internal/protocol/http.go new file mode 100644 index 00000000..abcc1a4f --- /dev/null +++ b/transport/hysteria2/core/internal/protocol/http.go @@ -0,0 +1,68 @@ +package protocol + +import ( + "net/http" + "strconv" +) + +const ( + URLHost = "hysteria" + URLPath = "/auth" + + RequestHeaderAuth = "Hysteria-Auth" + ResponseHeaderUDPEnabled = "Hysteria-UDP" + CommonHeaderCCRX = "Hysteria-CC-RX" + CommonHeaderPadding = "Hysteria-Padding" + + StatusAuthOK = 233 +) + +// AuthRequest is what client sends to server for authentication. +type AuthRequest struct { + Auth string + Rx uint64 // 0 = unknown, client asks server to use bandwidth detection +} + +// AuthResponse is what server sends to client when authentication is passed. 
+type AuthResponse struct { + UDPEnabled bool + Rx uint64 // 0 = unlimited + RxAuto bool // true = server asks client to use bandwidth detection +} + +func AuthRequestFromHeader(h http.Header) AuthRequest { + rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) + return AuthRequest{ + Auth: h.Get(RequestHeaderAuth), + Rx: rx, + } +} + +func AuthRequestToHeader(h http.Header, req AuthRequest) { + h.Set(RequestHeaderAuth, req.Auth) + h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10)) + h.Set(CommonHeaderPadding, authRequestPadding.String()) +} + +func AuthResponseFromHeader(h http.Header) AuthResponse { + resp := AuthResponse{} + resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) + rxStr := h.Get(CommonHeaderCCRX) + if rxStr == "auto" { + // Special case for server requesting client to use bandwidth detection + resp.RxAuto = true + } else { + resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64) + } + return resp +} + +func AuthResponseToHeader(h http.Header, resp AuthResponse) { + h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled)) + if resp.RxAuto { + h.Set(CommonHeaderCCRX, "auto") + } else { + h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10)) + } + h.Set(CommonHeaderPadding, authResponsePadding.String()) +} diff --git a/transport/hysteria2/core/internal/protocol/padding.go b/transport/hysteria2/core/internal/protocol/padding.go new file mode 100644 index 00000000..9895cdcc --- /dev/null +++ b/transport/hysteria2/core/internal/protocol/padding.go @@ -0,0 +1,31 @@ +package protocol + +import ( + "math/rand" +) + +const ( + paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +// padding specifies a half-open range [Min, Max). +type padding struct { + Min int + Max int +} + +func (p padding) String() string { + n := p.Min + rand.Intn(p.Max-p.Min) + bs := make([]byte, n) + for i := range bs { + bs[i] = paddingChars[rand.Intn(len(paddingChars))] + } + return string(bs) +} + +var ( + authRequestPadding = padding{Min: 256, Max: 2048} + authResponsePadding = padding{Min: 256, Max: 2048} + tcpRequestPadding = padding{Min: 64, Max: 512} + tcpResponsePadding = padding{Min: 128, Max: 1024} +) diff --git a/transport/hysteria2/core/internal/protocol/proxy.go b/transport/hysteria2/core/internal/protocol/proxy.go new file mode 100644 index 00000000..c2e046ed --- /dev/null +++ b/transport/hysteria2/core/internal/protocol/proxy.go @@ -0,0 +1,255 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/metacubex/mihomo/transport/hysteria2/core/errors" + + "github.com/metacubex/quic-go/quicvarint" +) + +const ( + FrameTypeTCPRequest = 0x401 + + // Max length values are for preventing DoS attacks + + MaxAddressLength = 2048 + MaxMessageLength = 2048 + MaxPaddingLength = 4096 + + MaxUDPSize = 4096 + + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// TCPRequest format: +// 0x401 (QUIC varint) +// Address length (QUIC varint) +// Address (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPRequest(r io.Reader) (string, error) { + bReader := quicvarint.NewReader(r) + addrLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if addrLen == 0 || addrLen > MaxAddressLength { + return "", errors.ProtocolError{Message: "invalid address length"} + } + addrBuf := make([]byte, addrLen) + _, err = io.ReadFull(r, addrBuf) + if err != nil { + return "", err + } + paddingLen, err := 
quicvarint.Read(bReader) + if err != nil { + return "", err + } + if paddingLen > MaxPaddingLength { + return "", errors.ProtocolError{Message: "invalid padding length"} + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return "", err + } + } + return string(addrBuf), nil +} + +func WriteTCPRequest(w io.Writer, addr string) error { + padding := tcpRequestPadding.String() + paddingLen := len(padding) + addrLen := len(addr) + sz := int(quicvarint.Len(FrameTypeTCPRequest)) + + int(quicvarint.Len(uint64(addrLen))) + addrLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buf := make([]byte, sz) + i := varintPut(buf, FrameTypeTCPRequest) + i += varintPut(buf[i:], uint64(addrLen)) + i += copy(buf[i:], addr) + i += varintPut(buf[i:], uint64(paddingLen)) + copy(buf[i:], padding) + _, err := w.Write(buf) + return err +} + +// TCPResponse format: +// Status (byte, 0=ok, 1=error) +// Message length (QUIC varint) +// Message (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPResponse(r io.Reader) (bool, string, error) { + var status [1]byte + if _, err := io.ReadFull(r, status[:]); err != nil { + return false, "", err + } + bReader := quicvarint.NewReader(r) + msgLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if msgLen > MaxMessageLength { + return false, "", errors.ProtocolError{Message: "invalid message length"} + } + var msgBuf []byte + // No message is fine + if msgLen > 0 { + msgBuf = make([]byte, msgLen) + _, err = io.ReadFull(r, msgBuf) + if err != nil { + return false, "", err + } + } + paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if paddingLen > MaxPaddingLength { + return false, "", errors.ProtocolError{Message: "invalid padding length"} + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return false, "", err + } + } + return status[0] == 0, string(msgBuf), nil +} + +func WriteTCPResponse(w io.Writer, ok bool, msg string) error { + padding := tcpResponsePadding.String() + paddingLen := len(padding) + msgLen := len(msg) + sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buf := make([]byte, sz) + if ok { + buf[0] = 0 + } else { + buf[0] = 1 + } + i := varintPut(buf[1:], uint64(msgLen)) + i += copy(buf[1+i:], msg) + i += varintPut(buf[1+i:], uint64(paddingLen)) + copy(buf[1+i:], padding) + _, err := w.Write(buf) + return err +} + +// UDPMessage format: +// Session ID (uint32 BE) +// Packet ID (uint16 BE) +// Fragment ID (uint8) +// Fragment count (uint8) +// Address length (QUIC varint) +// Address (bytes) +// Data... 
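The wire format listed above is small enough to check by hand. The snippet below is an editor's illustration, not part of the patch; it assumes the UDPMessage type defined just below, and the address and payload are placeholders. For an address of "1.2.3.4:53" (10 bytes, 1-byte varint length prefix) the header is 4 + 2 + 1 + 1 + 1 + 10 = 19 bytes, followed by the data.

// Editor's sketch (not part of the patch), assuming the UDPMessage type below.
func exampleUDPMessageSize() int {
	m := UDPMessage{
		SessionID: 1,
		PacketID:  1,
		FragID:    0,
		FragCount: 1,
		Addr:      "1.2.3.4:53",
		Data:      []byte("hello"),
	}
	buf := make([]byte, m.Size()) // 19-byte header + 5-byte payload = 24 bytes
	return m.Serialize(buf)       // returns 24; would return -1 if buf were too small
}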
+ +type UDPMessage struct { + SessionID uint32 // 4 + PacketID uint16 // 2 + FragID uint8 // 1 + FragCount uint8 // 1 + Addr string // varint + bytes + Data []byte +} + +func (m *UDPMessage) HeaderSize() int { + lAddr := len(m.Addr) + return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr +} + +func (m *UDPMessage) Size() int { + return m.HeaderSize() + len(m.Data) +} + +func (m *UDPMessage) Serialize(buf []byte) int { + // Make sure the buffer is big enough + if len(buf) < m.Size() { + return -1 + } + binary.BigEndian.PutUint32(buf, m.SessionID) + binary.BigEndian.PutUint16(buf[4:], m.PacketID) + buf[6] = m.FragID + buf[7] = m.FragCount + i := varintPut(buf[8:], uint64(len(m.Addr))) + i += copy(buf[8+i:], m.Addr) + i += copy(buf[8+i:], m.Data) + return 8 + i +} + +func ParseUDPMessage(msg []byte) (*UDPMessage, error) { + m := &UDPMessage{} + buf := bytes.NewBuffer(msg) + if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { + return nil, err + } + lAddr, err := quicvarint.Read(buf) + if err != nil { + return nil, err + } + if lAddr == 0 || lAddr > MaxMessageLength { + return nil, errors.ProtocolError{Message: "invalid address length"} + } + bs := buf.Bytes() + if len(bs) <= int(lAddr) { + // We use <= instead of < here as we expect at least one byte of data after the address + return nil, errors.ProtocolError{Message: "invalid message length"} + } + m.Addr = string(bs[:lAddr]) + m.Data = bs[lAddr:] + return m, nil +} + +// varintPut is like quicvarint.Append, but instead of appending to a slice, +// it writes to a fixed-size buffer. Returns the number of bytes written. +func varintPut(b []byte, i uint64) int { + if i <= maxVarInt1 { + b[0] = uint8(i) + return 1 + } + if i <= maxVarInt2 { + b[0] = uint8(i>>8) | 0x40 + b[1] = uint8(i) + return 2 + } + if i <= maxVarInt4 { + b[0] = uint8(i>>24) | 0x80 + b[1] = uint8(i >> 16) + b[2] = uint8(i >> 8) + b[3] = uint8(i) + return 4 + } + if i <= maxVarInt8 { + b[0] = uint8(i>>56) | 0xc0 + b[1] = uint8(i >> 48) + b[2] = uint8(i >> 40) + b[3] = uint8(i >> 32) + b[4] = uint8(i >> 24) + b[5] = uint8(i >> 16) + b[6] = uint8(i >> 8) + b[7] = uint8(i) + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/transport/hysteria2/core/internal/utils/atomic.go b/transport/hysteria2/core/internal/utils/atomic.go new file mode 100644 index 00000000..e3c3d977 --- /dev/null +++ b/transport/hysteria2/core/internal/utils/atomic.go @@ -0,0 +1,24 @@ +package utils + +import ( + "sync/atomic" + "time" +) + +type AtomicTime struct { + v atomic.Value +} + +func NewAtomicTime(t time.Time) *AtomicTime { + a := &AtomicTime{} + a.Set(t) + return a +} + +func (t *AtomicTime) Set(new time.Time) { + t.v.Store(new) +} + +func (t *AtomicTime) Get() time.Time { + return t.v.Load().(time.Time) +} diff --git a/transport/hysteria2/core/internal/utils/qstream.go b/transport/hysteria2/core/internal/utils/qstream.go new file mode 100644 index 00000000..392689e0 --- /dev/null +++ b/transport/hysteria2/core/internal/utils/qstream.go @@ -0,0 +1,62 @@ +package utils + +import ( + "context" + "time" + + "github.com/metacubex/quic-go" +) + +// QStream is a wrapper of quic.Stream that handles Close() in a way that +// makes more sense to us. 
By default, quic.Stream's Close() only closes +// the write side of the stream, not the read side. And if there is unread +// data, the stream is not really considered closed until either the data +// is drained or CancelRead() is called. +// References: +// - https://github.com/libp2p/go-libp2p/blob/master/p2p/transport/quic/stream.go +// - https://github.com/quic-go/quic-go/issues/3558 +// - https://github.com/quic-go/quic-go/issues/1599 +type QStream struct { + Stream quic.Stream +} + +func (s *QStream) StreamID() quic.StreamID { + return s.Stream.StreamID() +} + +func (s *QStream) Read(p []byte) (n int, err error) { + return s.Stream.Read(p) +} + +func (s *QStream) CancelRead(code quic.StreamErrorCode) { + s.Stream.CancelRead(code) +} + +func (s *QStream) SetReadDeadline(t time.Time) error { + return s.Stream.SetReadDeadline(t) +} + +func (s *QStream) Write(p []byte) (n int, err error) { + return s.Stream.Write(p) +} + +func (s *QStream) Close() error { + s.Stream.CancelRead(0) + return s.Stream.Close() +} + +func (s *QStream) CancelWrite(code quic.StreamErrorCode) { + s.Stream.CancelWrite(code) +} + +func (s *QStream) Context() context.Context { + return s.Stream.Context() +} + +func (s *QStream) SetWriteDeadline(t time.Time) error { + return s.Stream.SetWriteDeadline(t) +} + +func (s *QStream) SetDeadline(t time.Time) error { + return s.Stream.SetDeadline(t) +} diff --git a/transport/hysteria2/extras/correctnet/correctnet.go b/transport/hysteria2/extras/correctnet/correctnet.go new file mode 100644 index 00000000..06098259 --- /dev/null +++ b/transport/hysteria2/extras/correctnet/correctnet.go @@ -0,0 +1,92 @@ +package correctnet + +import ( + "net" + "net/http" + "strings" +) + +func extractIPFamily(ip net.IP) (family string) { + if len(ip) == 0 { + // real family independent wildcard address, such as ":443" + return "" + } + if p4 := ip.To4(); len(p4) == net.IPv4len { + return "4" + } + return "6" +} + +func tcpAddrNetwork(addr *net.TCPAddr) (network string) { + if addr == nil { + return "tcp" + } + return "tcp" + extractIPFamily(addr.IP) +} + +func udpAddrNetwork(addr *net.UDPAddr) (network string) { + if addr == nil { + return "udp" + } + return "udp" + extractIPFamily(addr.IP) +} + +func ipAddrNetwork(addr *net.IPAddr) (network string) { + if addr == nil { + return "ip" + } + return "ip" + extractIPFamily(addr.IP) +} + +func Listen(network, address string) (net.Listener, error) { + if network == "tcp" { + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return ListenTCP(network, tcpAddr) + } + return net.Listen(network, address) +} + +func ListenTCP(network string, laddr *net.TCPAddr) (*net.TCPListener, error) { + if network == "tcp" { + return net.ListenTCP(tcpAddrNetwork(laddr), laddr) + } + return net.ListenTCP(network, laddr) +} + +func ListenPacket(network, address string) (listener net.PacketConn, err error) { + if network == "udp" { + udpAddr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + return ListenUDP(network, udpAddr) + } + if strings.HasPrefix(network, "ip:") { + proto := network[3:] + ipAddr, err := net.ResolveIPAddr(proto, address) + if err != nil { + return nil, err + } + return net.ListenIP(ipAddrNetwork(ipAddr)+":"+proto, ipAddr) + } + return net.ListenPacket(network, address) +} + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + if network == "udp" { + return net.ListenUDP(udpAddrNetwork(laddr), laddr) + } + return net.ListenUDP(network, laddr) +} 
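The point of the helpers above is that a generic "tcp"/"udp" network is rewritten into "tcp4"/"udp4" or "tcp6"/"udp6" according to the family of the resolved literal address, instead of being passed through unchanged. The snippet below is an editor's illustration, not part of the patch; the address is a placeholder.

// Editor's sketch (not part of the patch): 127.0.0.1 resolves to an IPv4
// address, so ListenUDP above ends up listening on the "udp4" network.
func exampleCorrectnetListen() (net.PacketConn, error) {
	return ListenPacket("udp", "127.0.0.1:0")
}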
+ +func HTTPListenAndServe(address string, handler http.Handler) error { + listener, err := Listen("tcp", address) + if err != nil { + return err + } + defer listener.Close() + return http.Serve(listener, handler) +} diff --git a/transport/hysteria2/extras/obfs/conn.go b/transport/hysteria2/extras/obfs/conn.go new file mode 100644 index 00000000..46131917 --- /dev/null +++ b/transport/hysteria2/extras/obfs/conn.go @@ -0,0 +1,121 @@ +package obfs + +import ( + "net" + "sync" + "syscall" + "time" +) + +const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough + +// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods. +// Both methods return the number of bytes written to out. +// If a packet is not valid, the methods should return 0. +type Obfuscator interface { + Obfuscate(in, out []byte) int + Deobfuscate(in, out []byte) int +} + +var _ net.PacketConn = (*obfsPacketConn)(nil) + +type obfsPacketConn struct { + Conn net.PacketConn + Obfs Obfuscator + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn +// as the underlying connection. We pass additional methods to quic-go to +// enable UDP-specific optimizations. +type obfsPacketConnUDP struct { + *obfsPacketConn + UDPConn *net.UDPConn +} + +// WrapPacketConn enables obfuscation on a net.PacketConn. +// The obfuscation is transparent to the caller - the n bytes returned by +// ReadFrom and WriteTo are the number of original bytes, not after +// obfuscation/deobfuscation. +func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn { + opc := &obfsPacketConn{ + Conn: conn, + Obfs: obfs, + readBuf: make([]byte, udpBufferSize), + writeBuf: make([]byte, udpBufferSize), + } + if udpConn, ok := conn.(*net.UDPConn); ok { + return &obfsPacketConnUDP{ + obfsPacketConn: opc, + UDPConn: udpConn, + } + } else { + return opc + } +} + +func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + c.readMutex.Lock() + n, addr, err = c.Conn.ReadFrom(c.readBuf) + if n <= 0 { + c.readMutex.Unlock() + return + } + n = c.Obfs.Deobfuscate(c.readBuf[:n], p) + c.readMutex.Unlock() + if n > 0 || err != nil { + return + } + // Invalid packet, try again + } +} + +func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.writeMutex.Lock() + nn := c.Obfs.Obfuscate(p, c.writeBuf) + _, err = c.Conn.WriteTo(c.writeBuf[:nn], addr) + c.writeMutex.Unlock() + if err == nil { + n = len(p) + } + return +} + +func (c *obfsPacketConn) Close() error { + return c.Conn.Close() +} + +func (c *obfsPacketConn) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c *obfsPacketConn) SetDeadline(t time.Time) error { + return c.Conn.SetDeadline(t) +} + +func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { + return c.Conn.SetReadDeadline(t) +} + +func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { + return c.Conn.SetWriteDeadline(t) +} + +// UDP-specific methods below + +func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error { + return c.UDPConn.SetReadBuffer(bytes) +} + +func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error { + return c.UDPConn.SetWriteBuffer(bytes) +} + +func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) { + return c.UDPConn.SyscallConn() +} diff --git a/transport/hysteria2/extras/obfs/salamander.go b/transport/hysteria2/extras/obfs/salamander.go new file mode 100644 
index 00000000..50a3ce26 --- /dev/null +++ b/transport/hysteria2/extras/obfs/salamander.go @@ -0,0 +1,71 @@ +package obfs + +import ( + "fmt" + "math/rand" + "sync" + "time" + + "golang.org/x/crypto/blake2b" +) + +const ( + smPSKMinLen = 4 + smSaltLen = 8 + smKeyLen = blake2b.Size256 +) + +var _ Obfuscator = (*SalamanderObfuscator)(nil) + +var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen) + +// SalamanderObfuscator is an obfuscator that obfuscates each packet with +// the BLAKE2b-256 hash of a pre-shared key combined with a random salt. +// Packet format: [8-byte salt][payload] +type SalamanderObfuscator struct { + PSK []byte + RandSrc *rand.Rand + + lk sync.Mutex +} + +func NewSalamanderObfuscator(psk []byte) (*SalamanderObfuscator, error) { + if len(psk) < smPSKMinLen { + return nil, ErrPSKTooShort + } + return &SalamanderObfuscator{ + PSK: psk, + RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())), + }, nil +} + +func (o *SalamanderObfuscator) Obfuscate(in, out []byte) int { + outLen := len(in) + smSaltLen + if len(out) < outLen { + return 0 + } + o.lk.Lock() + _, _ = o.RandSrc.Read(out[:smSaltLen]) + o.lk.Unlock() + key := o.key(out[:smSaltLen]) + for i, c := range in { + out[i+smSaltLen] = c ^ key[i%smKeyLen] + } + return outLen +} + +func (o *SalamanderObfuscator) Deobfuscate(in, out []byte) int { + outLen := len(in) - smSaltLen + if outLen <= 0 || len(out) < outLen { + return 0 + } + key := o.key(in[:smSaltLen]) + for i, c := range in[smSaltLen:] { + out[i] = c ^ key[i%smKeyLen] + } + return outLen +} + +func (o *SalamanderObfuscator) key(salt []byte) [smKeyLen]byte { + return blake2b.Sum256(append(o.PSK, salt...)) +} diff --git a/transport/hysteria2/extras/obfs/salamander_test.go b/transport/hysteria2/extras/obfs/salamander_test.go new file mode 100644 index 00000000..85eafdcc --- /dev/null +++ b/transport/hysteria2/extras/obfs/salamander_test.go @@ -0,0 +1,45 @@ +package obfs + +import ( + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) { + o, _ := NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + _, _ = rand.Read(in) + out := make([]byte, 2048) + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Obfuscate(in, out) + } +} + +func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) { + o, _ := NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + _, _ = rand.Read(in) + out := make([]byte, 2048) + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Deobfuscate(in, out) + } +} + +func TestSalamanderObfuscator(t *testing.T) { + o, _ := NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + oOut := make([]byte, 2048) + dOut := make([]byte, 2048) + for i := 0; i < 1000; i++ { + _, _ = rand.Read(in) + n := o.Obfuscate(in, oOut) + assert.Equal(t, len(in)+smSaltLen, n) + n = o.Deobfuscate(oOut[:n], dOut) + assert.Equal(t, len(in), n) + assert.Equal(t, in, dOut[:n]) + } +} diff --git a/transport/hysteria2/extras/transport/udphop/addr.go b/transport/hysteria2/extras/transport/udphop/addr.go new file mode 100644 index 00000000..3c704728 --- /dev/null +++ b/transport/hysteria2/extras/transport/udphop/addr.go @@ -0,0 +1,92 @@ +package udphop + +import ( + "fmt" + "net" + "strconv" + "strings" +) + +type InvalidPortError struct { + PortStr string +} + +func (e InvalidPortError) Error() string { + return fmt.Sprintf("%s is not a valid port number or range", e.PortStr) +} + +// UDPHopAddr contains an 
IP address and a list of ports. +type UDPHopAddr struct { + IP net.IP + Ports []uint16 + PortStr string +} + +func (a *UDPHopAddr) Network() string { + return "udphop" +} + +func (a *UDPHopAddr) String() string { + return net.JoinHostPort(a.IP.String(), a.PortStr) +} + +// addrs returns a list of net.Addr's, one for each port. +func (a *UDPHopAddr) addrs() ([]net.Addr, error) { + var addrs []net.Addr + for _, port := range a.Ports { + addr := &net.UDPAddr{ + IP: a.IP, + Port: int(port), + } + addrs = append(addrs, addr) + } + return addrs, nil +} + +func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ip, err := net.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + result := &UDPHopAddr{ + IP: ip.IP, + PortStr: portStr, + } + + portStrs := strings.Split(portStr, ",") + for _, portStr := range portStrs { + if strings.Contains(portStr, "-") { + // Port range + portRange := strings.Split(portStr, "-") + if len(portRange) != 2 { + return nil, InvalidPortError{portStr} + } + start, err := strconv.ParseUint(portRange[0], 10, 16) + if err != nil { + return nil, InvalidPortError{portStr} + } + end, err := strconv.ParseUint(portRange[1], 10, 16) + if err != nil { + return nil, InvalidPortError{portStr} + } + if start > end { + start, end = end, start + } + for i := start; i <= end; i++ { + result.Ports = append(result.Ports, uint16(i)) + } + } else { + // Single port + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, InvalidPortError{portStr} + } + result.Ports = append(result.Ports, uint16(port)) + } + } + return result, nil +} diff --git a/transport/hysteria2/extras/transport/udphop/conn.go b/transport/hysteria2/extras/transport/udphop/conn.go new file mode 100644 index 00000000..44f7bf03 --- /dev/null +++ b/transport/hysteria2/extras/transport/udphop/conn.go @@ -0,0 +1,288 @@ +package udphop + +import ( + "errors" + "math/rand" + "net" + "sync" + "syscall" + "time" + + "github.com/metacubex/mihomo/log" +) + +const ( + packetQueueSize = 1024 + udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough + + defaultHopInterval = 30 * time.Second +) + +type udpHopPacketConn struct { + Addr net.Addr + Addrs []net.Addr + HopInterval time.Duration + + connMutex sync.RWMutex + prevConn net.PacketConn + currentConn net.PacketConn + addrIndex int + + readBufferSize int + writeBufferSize int + + recvQueue chan *udpPacket + closeChan chan struct{} + closed bool + + bufPool sync.Pool +} + +type udpPacket struct { + Buf []byte + N int + Addr net.Addr + Err error +} + +func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) { + if hopInterval == 0 { + hopInterval = defaultHopInterval + } else if hopInterval < 5*time.Second { + return nil, errors.New("hop interval must be at least 5 seconds") + } + addrs, err := addr.addrs() + if err != nil { + return nil, err + } + curConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + hConn := &udpHopPacketConn{ + Addr: addr, + Addrs: addrs, + HopInterval: hopInterval, + prevConn: nil, + currentConn: curConn, + addrIndex: rand.Intn(len(addrs)), + recvQueue: make(chan *udpPacket, packetQueueSize), + closeChan: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + return make([]byte, udpBufferSize) + }, + }, + } + go hConn.recvLoop(curConn) + go hConn.hopLoop() + return hConn, nil +} + +func (u *udpHopPacketConn) 
+	for {
+		buf := u.bufPool.Get().([]byte)
+		n, addr, err := conn.ReadFrom(buf)
+		if err != nil {
+			u.bufPool.Put(buf)
+			var netErr net.Error
+			if errors.As(err, &netErr) && netErr.Timeout() {
+				// Only pass through timeout errors here, not permanent errors
+				// like connection closed. Connection close is normal as we close
+				// the old connection to exit this loop every time we hop.
+				u.recvQueue <- &udpPacket{nil, 0, nil, netErr}
+			}
+			return
+		}
+		select {
+		case u.recvQueue <- &udpPacket{buf, n, addr, nil}:
+			// Packet successfully queued
+		default:
+			// Queue is full, drop the packet
+			u.bufPool.Put(buf)
+		}
+	}
+}
+
+func (u *udpHopPacketConn) hopLoop() {
+	ticker := time.NewTicker(u.HopInterval)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+			u.hop()
+		case <-u.closeChan:
+			return
+		}
+	}
+}
+
+func (u *udpHopPacketConn) hop() {
+	u.connMutex.Lock()
+	defer u.connMutex.Unlock()
+	if u.closed {
+		return
+	}
+	newConn, err := net.ListenUDP("udp", nil)
+	if err != nil {
+		// Could be temporary, just skip this hop
+		return
+	}
+	// We need to keep receiving packets from the previous connection,
+	// because otherwise there will be packet loss due to the time gap
+	// between when we hop to a new port and when the server acknowledges the change.
+	// So we do the following:
+	// Close prevConn,
+	// move currentConn to prevConn,
+	// set newConn as currentConn,
+	// start recvLoop on newConn.
+	if u.prevConn != nil {
+		_ = u.prevConn.Close() // recvLoop for this conn will exit
+	}
+	u.prevConn = u.currentConn
+	u.currentConn = newConn
+	// Set buffer sizes if previously set
+	if u.readBufferSize > 0 {
+		_ = trySetReadBuffer(u.currentConn, u.readBufferSize)
+	}
+	if u.writeBufferSize > 0 {
+		_ = trySetWriteBuffer(u.currentConn, u.writeBufferSize)
+	}
+	go u.recvLoop(newConn)
+	// Update addrIndex to a new random value
+	u.addrIndex = rand.Intn(len(u.Addrs))
+	log.Infoln("hopped to %s", u.Addrs[u.addrIndex])
+}
+
+func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
+	for {
+		select {
+		case p := <-u.recvQueue:
+			if p.Err != nil {
+				return 0, nil, p.Err
+			}
+			// Currently we do not check whether the packet is from
+			// the server or not for performance reasons.
+			n := copy(b, p.Buf[:p.N])
+			u.bufPool.Put(p.Buf)
+			return n, u.Addr, nil
+		case <-u.closeChan:
+			return 0, nil, net.ErrClosed
+		}
+	}
+}
+
+func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
+	u.connMutex.RLock()
+	defer u.connMutex.RUnlock()
+	if u.closed {
+		return 0, net.ErrClosed
+	}
+	// Skip the check for now, always write to the server,
+	// for the same reason as in ReadFrom.
+	return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex])
+}
+
+func (u *udpHopPacketConn) Close() error {
+	u.connMutex.Lock()
+	defer u.connMutex.Unlock()
+	if u.closed {
+		return nil
+	}
+	// Close prevConn and currentConn
+	// Close closeChan to unblock ReadFrom & hopLoop
+	// Set closed flag to true to prevent double close
+	if u.prevConn != nil {
+		_ = u.prevConn.Close()
+	}
+	err := u.currentConn.Close()
+	close(u.closeChan)
+	u.closed = true
+	u.Addrs = nil // For GC
+	return err
+}
+
+func (u *udpHopPacketConn) LocalAddr() net.Addr {
+	u.connMutex.RLock()
+	defer u.connMutex.RUnlock()
+	return u.currentConn.LocalAddr()
+}
+
+func (u *udpHopPacketConn) SetDeadline(t time.Time) error {
+	u.connMutex.RLock()
+	defer u.connMutex.RUnlock()
+	if u.prevConn != nil {
+		_ = u.prevConn.SetDeadline(t)
+	}
+	return u.currentConn.SetDeadline(t)
+}
+
+func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error {
+	u.connMutex.RLock()
+	defer u.connMutex.RUnlock()
+	if u.prevConn != nil {
+		_ = u.prevConn.SetReadDeadline(t)
+	}
+	return u.currentConn.SetReadDeadline(t)
+}
+
+func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error {
+	u.connMutex.RLock()
+	defer u.connMutex.RUnlock()
+	if u.prevConn != nil {
+		_ = u.prevConn.SetWriteDeadline(t)
+	}
+	return u.currentConn.SetWriteDeadline(t)
+}
+
+// UDP-specific methods below
+
+func (u *udpHopPacketConn) SetReadBuffer(bytes int) error {
+	u.connMutex.Lock()
+	defer u.connMutex.Unlock()
+	u.readBufferSize = bytes
+	if u.prevConn != nil {
+		_ = trySetReadBuffer(u.prevConn, bytes)
+	}
+	return trySetReadBuffer(u.currentConn, bytes)
+}
+
+func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error {
+	u.connMutex.Lock()
+	defer u.connMutex.Unlock()
+	u.writeBufferSize = bytes
+	if u.prevConn != nil {
+		_ = trySetWriteBuffer(u.prevConn, bytes)
+	}
+	return trySetWriteBuffer(u.currentConn, bytes)
+}
+
+func (u *udpHopPacketConn) SyscallConn() (syscall.RawConn, error) {
+	u.connMutex.RLock()
+	defer u.connMutex.RUnlock()
+	sc, ok := u.currentConn.(syscall.Conn)
+	if !ok {
+		return nil, errors.New("not supported")
+	}
+	return sc.SyscallConn()
+}
+
+func trySetReadBuffer(pc net.PacketConn, bytes int) error {
+	sc, ok := pc.(interface {
+		SetReadBuffer(bytes int) error
+	})
+	if ok {
+		return sc.SetReadBuffer(bytes)
+	}
+	return nil
+}
+
+func trySetWriteBuffer(pc net.PacketConn, bytes int) error {
+	sc, ok := pc.(interface {
+		SetWriteBuffer(bytes int) error
+	})
+	if ok {
+		return sc.SetWriteBuffer(bytes)
+	}
+	return nil
+}
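
Note for reviewers (not part of the patch): a minimal usage sketch of the new udphop package added above, showing how a caller could turn a host plus a ports expression (single ports, comma-separated lists, and dash ranges, as parsed by ResolveUDPHopAddr) into a hopping net.PacketConn via NewUDPHopPacketConn. The server address "203.0.113.1", the ports string, and the 10-second interval are illustrative values only.

package main

import (
	"fmt"
	"time"

	"github.com/metacubex/mihomo/transport/hysteria2/extras/transport/udphop"
)

func main() {
	// The part after the last colon is the ports expression,
	// e.g. "443", "443,8443" or "20000-30000".
	addr, err := udphop.ResolveUDPHopAddr("203.0.113.1:443,20000-30000")
	if err != nil {
		panic(err)
	}
	// A zero interval falls back to the 30s default; intervals below 5s are rejected.
	pc, err := udphop.NewUDPHopPacketConn(addr, 10*time.Second)
	if err != nil {
		panic(err)
	}
	defer pc.Close()
	fmt.Println("hopping conn bound to", pc.LocalAddr())
	// pc can now be handed to a QUIC transport as its net.PacketConn.
}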