support port hopping for hysteria 2

keakon 2024-01-17 20:00:01 +08:00
parent 460cc240b0
commit 2a0139e236
36 changed files with 5282 additions and 101 deletions


@ -2,137 +2,123 @@ package outbound
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"runtime"
"strconv"
"time"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
tuicCommon "github.com/metacubex/mihomo/transport/tuic/common"
"github.com/metacubex/sing-quic/hysteria2"
M "github.com/sagernet/sing/common/metadata"
"github.com/metacubex/mihomo/transport/hysteria2/app/cmd"
hy2client "github.com/metacubex/mihomo/transport/hysteria2/core/client"
)
func init() {
hysteria2.SetCongestionController = tuicCommon.SetCongestionController
}
const minHopInterval = 5
const defaultHopInterval = 30
type Hysteria2 struct {
*Base
option *Hysteria2Option
client *hysteria2.Client
dialer proxydialer.SingDialer
client hy2client.Client
}
type Hysteria2Option struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Up string `proxy:"up,omitempty"`
Down string `proxy:"down,omitempty"`
Password string `proxy:"password,omitempty"`
Obfs string `proxy:"obfs,omitempty"`
ObfsPassword string `proxy:"obfs-password,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
ALPN []string `proxy:"alpn,omitempty"`
CustomCA string `proxy:"ca,omitempty"`
CustomCAString string `proxy:"ca-str,omitempty"`
CWND int `proxy:"cwnd,omitempty"`
Name string `proxy:"name"`
Server string `proxy:"server"`
Port uint16 `proxy:"port,omitempty"`
Ports string `proxy:"ports,omitempty"`
HopInterval time.Duration `proxy:"hop-interval,omitempty"`
Up string `proxy:"up"`
Down string `proxy:"down"`
Password string `proxy:"password,omitempty"`
Obfs string `proxy:"obfs,omitempty"`
ObfsPassword string `proxy:"obfs-password,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
ALPN []string `proxy:"alpn,omitempty"`
CustomCA string `proxy:"ca,omitempty"`
CustomCAString string `proxy:"ca-str,omitempty"`
CWND int `proxy:"cwnd,omitempty"`
FastOpen bool `proxy:"fast-open,omitempty"`
Lazy bool `proxy:"lazy,omitempty"`
}
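// Illustrative sketch (not part of this diff): a hypothetical option value
// exercising the new port-hopping fields. The server, ports and credentials are
// made up; field names come from the Hysteria2Option struct above.
var exampleHopOption = Hysteria2Option{
	Name:        "hy2-hop",
	Server:      "example.com",
	Ports:       "20000-30000,40500", // "-" or "," marks a port range to hop across
	HopInterval: 30,                  // in seconds; NewHysteria2 clamps it to >= 5 (default 30) and multiplies by time.Second
	Password:    "password",
	Up:          "100",               // a bare number gets an "m" (Mbps) suffix appended in NewHysteria2
	Down:        "200",
}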
func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
options := h.Base.DialOptions(opts...)
h.dialer.SetDialer(dialer.NewDialer(options...))
c, err := h.client.DialConn(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
tcpConn, err := h.client.TCP(net.JoinHostPort(metadata.String(), strconv.Itoa(int(metadata.DstPort))))
if err != nil {
return nil, err
}
return NewConn(CN.NewRefConn(c, h), h), nil
return NewConn(tcpConn, h), nil
}
func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
options := h.Base.DialOptions(opts...)
h.dialer.SetDialer(dialer.NewDialer(options...))
pc, err := h.client.ListenPacket(ctx)
udpConn, err := h.client.UDP()
if err != nil {
return nil, err
}
if pc == nil {
return nil, errors.New("packetConn is nil")
}
return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), h), h), nil
}
func closeHysteria2(h *Hysteria2) {
if h.client != nil {
_ = h.client.CloseWithError(errors.New("proxy removed"))
}
return newPacketConn(udpConn, h), nil
}
func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
var salamanderPassword string
if len(option.Obfs) > 0 {
if option.ObfsPassword == "" {
return nil, errors.New("missing obfs password")
}
switch option.Obfs {
case hysteria2.ObfsTypeSalamander:
salamanderPassword = option.ObfsPassword
default:
return nil, fmt.Errorf("unknown obfs type: %s", option.Obfs)
}
var server string
if option.Ports != "" {
server = net.JoinHostPort(option.Server, option.Ports)
} else {
server = net.JoinHostPort(option.Server, strconv.Itoa(int(option.Port)))
}
serverName := option.Server
if option.SNI != "" {
serverName = option.SNI
if option.HopInterval == 0 {
option.HopInterval = defaultHopInterval
} else if option.HopInterval < minHopInterval {
option.HopInterval = minHopInterval
}
option.HopInterval *= time.Second
config := cmd.ClientConfig{
Server: server,
Auth: option.Password,
Transport: cmd.ClientConfigTransport{
UDP: cmd.ClientConfigTransportUDP{
HopInterval: option.HopInterval,
},
},
TLS: cmd.ClientConfigTLS{
SNI: option.SNI,
Insecure: option.SkipCertVerify,
PinSHA256: option.Fingerprint,
CA: option.CustomCA,
CAString: option.CustomCAString,
},
FastOpen: option.FastOpen,
Lazy: option.Lazy,
}
tlsConfig := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: option.SkipCertVerify,
MinVersion: tls.VersionTLS13,
if option.ObfsPassword != "" {
config.Obfs.Type = "salamander"
config.Obfs.Salamander.Password = option.ObfsPassword
} else if option.Obfs != "" {
config.Obfs.Type = "salamander"
config.Obfs.Salamander.Password = option.Obfs
}
var err error
tlsConfig, err = ca.GetTLSConfig(tlsConfig, option.Fingerprint, option.CustomCA, option.CustomCAString)
if err != nil {
return nil, err
last := option.Up[len(option.Up)-1]
if '0' <= last && last <= '9' {
option.Up += "m"
}
if len(option.ALPN) > 0 {
tlsConfig.NextProtos = option.ALPN
config.Bandwidth.Up = option.Up
last = option.Down[len(option.Down)-1]
if '0' <= last && last <= '9' {
option.Down += "m"
}
config.Bandwidth.Down = option.Down
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer())
clientOptions := hysteria2.ClientOptions{
Context: context.TODO(),
Dialer: singDialer,
ServerAddress: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)),
SendBPS: StringToBps(option.Up),
ReceiveBPS: StringToBps(option.Down),
SalamanderPassword: salamanderPassword,
Password: option.Password,
TLSConfig: tlsConfig,
UDPDisabled: false,
CWND: option.CWND,
}
client, err := hysteria2.NewClient(clientOptions)
client, err := hy2client.NewReconnectableClient(
config.Config,
func(c hy2client.Client, info *hy2client.HandshakeInfo, count int) {},
option.Lazy)
if err != nil {
return nil, err
}
@ -140,7 +126,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
outbound := &Hysteria2{
Base: &Base{
name: option.Name,
addr: addr,
addr: server,
tp: C.Hysteria2,
udp: true,
iface: option.Interface,
@ -149,9 +135,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
},
option: &option,
client: client,
dialer: singDialer,
}
runtime.SetFinalizer(outbound, closeHysteria2)
return outbound, nil
}

go.mod

@ -1,6 +1,6 @@
module github.com/metacubex/mihomo
go 1.20
go 1.21
require (
github.com/3andne/restls-client-go v0.1.6
@ -65,11 +65,12 @@ require (
github.com/andybalholm/brotli v1.0.6 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cloudflare/circl v1.3.6 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/ericlagergren/aegis v0.0.0-20230312195928-b4ce538b56f9 // indirect
github.com/ericlagergren/polyval v0.0.0-20220411101811-e25bc10ba391 // indirect
github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 // indirect
github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010 // indirect
github.com/frankban/quicktest v1.14.6 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gaukas/godicttls v0.0.4 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
@ -89,7 +90,7 @@ require (
github.com/oasisprotocol/deoxysii v0.0.0-20220228165953-2091330c22b7 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
@ -110,6 +111,7 @@ require (
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.16.0 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)
replace github.com/sagernet/sing => github.com/metacubex/sing v0.0.0-20240111014253-f1818b6a82b2

go.sum

@ -25,9 +25,11 @@ github.com/cloudflare/circl v1.3.6 h1:/xbKIqSHbZXHwkhbrhrt2YOHIwYJlXH94E3tI/gDlU
github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/ericlagergren/aegis v0.0.0-20230312195928-b4ce538b56f9 h1:/5RkVc9Rc81XmMyVqawCiDyrBHZbLAZgTTCqou4mwj8=
@ -39,7 +41,8 @@ github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 h1:tlDMEdcPRQKBE
github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1/go.mod h1:4RfsapbGx2j/vU5xC/5/9qB3kn9Awp1YDiEnN43QrJ4=
github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010 h1:fuGucgPk5dN6wzfnxl3D0D3rVLw4v2SbBT9jb4VnxzA=
github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010/go.mod h1:JtBcj7sBuTTRupn7c2bFspMDIObMJsVK8TeUvpShPok=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk=
@ -91,7 +94,9 @@ github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6K
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc=
@ -140,9 +145,11 @@ github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq5
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE=
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
@ -153,6 +160,7 @@ github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1
github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs=
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkkD2QgdTuzQG263YZ+2emfpeyGqW0=
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
@ -270,7 +278,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@ -0,0 +1 @@
Copied from [hysteria](https://github.com/apernet/hysteria) v2.2.3 with minor changes.


@ -0,0 +1,374 @@
package cmd
import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"errors"
"net"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/metacubex/mihomo/transport/hysteria2/app/utils"
"github.com/metacubex/mihomo/transport/hysteria2/core/client"
"github.com/metacubex/mihomo/transport/hysteria2/extras/obfs"
"github.com/metacubex/mihomo/transport/hysteria2/extras/transport/udphop"
)
type ClientConfig struct {
Server string `mapstructure:"server"`
Auth string `mapstructure:"auth"`
Transport ClientConfigTransport `mapstructure:"transport"`
Obfs ClientConfigObfs `mapstructure:"obfs"`
TLS ClientConfigTLS `mapstructure:"tls"`
QUIC clientConfigQUIC `mapstructure:"quic"`
Bandwidth ClientConfigBandwidth `mapstructure:"bandwidth"`
FastOpen bool `mapstructure:"fastOpen"`
Lazy bool `mapstructure:"lazy"`
}
type ClientConfigTransportUDP struct {
HopInterval time.Duration `mapstructure:"hopInterval"`
}
type ClientConfigTransport struct {
Type string `mapstructure:"type"`
UDP ClientConfigTransportUDP `mapstructure:"udp"`
}
type ClientConfigObfsSalamander struct {
Password string `mapstructure:"password"`
}
type ClientConfigObfs struct {
Type string `mapstructure:"type"`
Salamander ClientConfigObfsSalamander `mapstructure:"salamander"`
}
type ClientConfigTLS struct {
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
PinSHA256 string `mapstructure:"pinSHA256"`
CA string `mapstructure:"ca"`
CAString string `mapstructure:"ca-str"`
}
type clientConfigQUIC struct {
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
}
type ClientConfigBandwidth struct {
Up string `mapstructure:"up"`
Down string `mapstructure:"down"`
}
func (c *ClientConfig) fillServerAddr(hyConfig *client.Config) error {
if c.Server == "" {
return configError{Field: "server", Err: errors.New("server address is empty")}
}
var addr net.Addr
var err error
host, port, hostPort := parseServerAddrString(c.Server)
if !isPortHoppingPort(port) {
addr, err = net.ResolveUDPAddr("udp", hostPort)
} else {
addr, err = udphop.ResolveUDPHopAddr(hostPort)
}
if err != nil {
return configError{Field: "server", Err: err}
}
hyConfig.ServerAddr = addr
// Special handling for SNI
if c.TLS.SNI == "" {
// Use server hostname as SNI
hyConfig.TLSConfig.ServerName = host
}
return nil
}
// fillConnFactory must be called after fillServerAddr, as we have different logic
// for ConnFactory depending on whether we have a port hopping address.
func (c *ClientConfig) fillConnFactory(hyConfig *client.Config) error {
// Inner PacketConn
var newFunc func(addr net.Addr) (net.PacketConn, error)
switch strings.ToLower(c.Transport.Type) {
case "", "udp":
if hyConfig.ServerAddr.Network() == "udphop" {
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval)
}
} else {
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
}
}
default:
return configError{Field: "transport.type", Err: errors.New("unsupported transport type")}
}
// Obfuscation
var ob obfs.Obfuscator
var err error
switch strings.ToLower(c.Obfs.Type) {
case "", "plain":
// Keep it nil
case "salamander":
ob, err = obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password))
if err != nil {
return configError{Field: "obfs.salamander.password", Err: err}
}
default:
return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")}
}
hyConfig.ConnFactory = &adaptiveConnFactory{
NewFunc: newFunc,
Obfuscator: ob,
}
return nil
}
func (c *ClientConfig) fillAuth(hyConfig *client.Config) error {
hyConfig.Auth = c.Auth
return nil
}
func (c *ClientConfig) fillTLSConfig(hyConfig *client.Config) error {
if c.TLS.SNI != "" {
hyConfig.TLSConfig.ServerName = c.TLS.SNI
}
hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure
if c.TLS.PinSHA256 != "" {
nHash := normalizeCertHash(c.TLS.PinSHA256)
hyConfig.TLSConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
for _, cert := range rawCerts {
hash := sha256.Sum256(cert)
hashHex := hex.EncodeToString(hash[:])
if hashHex == nHash {
return nil
}
}
// No match
return errors.New("no certificate matches the pinned hash")
}
}
if c.TLS.CAString != "" || c.TLS.CA != "" {
var ca []byte
if c.TLS.CAString != "" {
ca = []byte(c.TLS.CAString)
} else {
var err error
ca, err = os.ReadFile(c.TLS.CA)
if err != nil {
return configError{Field: "tls.ca", Err: err}
}
}
cPool := x509.NewCertPool()
if !cPool.AppendCertsFromPEM(ca) {
return configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
}
hyConfig.TLSConfig.RootCAs = cPool
}
return nil
}
func (c *ClientConfig) fillQUICConfig(hyConfig *client.Config) error {
hyConfig.QUICConfig = client.QUICConfig{
InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
KeepAlivePeriod: c.QUIC.KeepAlivePeriod,
DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
}
return nil
}
func (c *ClientConfig) fillBandwidthConfig(hyConfig *client.Config) error {
// New core now allows users to omit bandwidth values and use built-in congestion control
var err error
if c.Bandwidth.Up != "" {
hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up)
if err != nil {
return configError{Field: "bandwidth.up", Err: err}
}
}
if c.Bandwidth.Down != "" {
hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down)
if err != nil {
return configError{Field: "bandwidth.down", Err: err}
}
}
return nil
}
func (c *ClientConfig) fillFastOpen(hyConfig *client.Config) error {
hyConfig.FastOpen = c.FastOpen
return nil
}
// URI generates a URI for sharing the config with others.
// Note that only the bare minimum of information required to
// connect to the server is included in the URI, specifically:
// - server address
// - authentication
// - obfuscation type
// - obfuscation password
// - TLS SNI
// - TLS insecure
// - TLS pinned SHA256 hash (normalized)
func (c *ClientConfig) URI() string {
q := url.Values{}
switch strings.ToLower(c.Obfs.Type) {
case "salamander":
q.Set("obfs", "salamander")
q.Set("obfs-password", c.Obfs.Salamander.Password)
}
if c.TLS.SNI != "" {
q.Set("sni", c.TLS.SNI)
}
if c.TLS.Insecure {
q.Set("insecure", "1")
}
if c.TLS.PinSHA256 != "" {
q.Set("pinSHA256", normalizeCertHash(c.TLS.PinSHA256))
}
var user *url.Userinfo
if c.Auth != "" {
// We need to handle the special case of user:pass pairs
rs := strings.SplitN(c.Auth, ":", 2)
if len(rs) == 2 {
user = url.UserPassword(rs[0], rs[1])
} else {
user = url.User(c.Auth)
}
}
u := url.URL{
Scheme: "hysteria2",
User: user,
Host: c.Server,
Path: "/",
RawQuery: q.Encode(),
}
return u.String()
}
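// Illustrative sketch (not part of this diff): a minimal config with made-up
// values serializes to "hysteria2://password@example.com:443/?sni=example.com".
func exampleShareURI() string {
	c := &ClientConfig{Server: "example.com:443", Auth: "password"}
	c.TLS.SNI = "example.com"
	return c.URI()
}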
// parseURI tries to parse the server address field as a URI,
// and fills the config with the information contained in the URI.
// Returns whether the server address field is a valid URI.
// This allows a user to put a URI as the server address and
// omit the fields that are already contained in the URI.
func (c *ClientConfig) parseURI() bool {
u, err := url.Parse(c.Server)
if err != nil {
return false
}
if u.Scheme != "hysteria2" && u.Scheme != "hy2" {
return false
}
if u.User != nil {
c.Auth = u.User.String()
}
c.Server = u.Host
q := u.Query()
if obfsType := q.Get("obfs"); obfsType != "" {
c.Obfs.Type = obfsType
switch strings.ToLower(obfsType) {
case "salamander":
c.Obfs.Salamander.Password = q.Get("obfs-password")
}
}
if sni := q.Get("sni"); sni != "" {
c.TLS.SNI = sni
}
if insecure, err := strconv.ParseBool(q.Get("insecure")); err == nil {
c.TLS.Insecure = insecure
}
if pinSHA256 := q.Get("pinSHA256"); pinSHA256 != "" {
c.TLS.PinSHA256 = pinSHA256
}
return true
}
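// Illustrative sketch (not part of this diff): parseURI lets the "server" field
// carry a full share link; all values here are made up.
func exampleParseURI() (string, string) {
	c := &ClientConfig{Server: "hysteria2://password@example.com:443/?sni=example.com"}
	c.parseURI()            // returns true and fills Auth, Server and TLS.SNI
	return c.Server, c.Auth // "example.com:443", "password"
}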
// Config validates the fields and returns a ready-to-use Hysteria client config
func (c *ClientConfig) Config() (*client.Config, error) {
c.parseURI()
hyConfig := &client.Config{}
fillers := []func(*client.Config) error{
c.fillServerAddr,
c.fillConnFactory,
c.fillAuth,
c.fillTLSConfig,
c.fillQUICConfig,
c.fillBandwidthConfig,
c.fillFastOpen,
}
for _, f := range fillers {
if err := f(hyConfig); err != nil {
return nil, err
}
}
return hyConfig, nil
}
type clientModeRunner struct {
ModeMap map[string]func() error
}
func (r *clientModeRunner) Add(name string, f func() error) {
if r.ModeMap == nil {
r.ModeMap = make(map[string]func() error)
}
r.ModeMap[name] = f
}
// parseServerAddrString parses a server address string.
// Server address can be in either "host:port" or "host" format (in which case we assume port 443).
func parseServerAddrString(addrStr string) (host, port, hostPort string) {
h, p, err := net.SplitHostPort(addrStr)
if err != nil {
return addrStr, "443", net.JoinHostPort(addrStr, "443")
}
return h, p, addrStr
}
// isPortHoppingPort returns whether the port string is a port hopping port.
// We consider a port string to be a port hopping port if it contains "-" or ",".
func isPortHoppingPort(port string) bool {
return strings.Contains(port, "-") || strings.Contains(port, ",")
}
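// Illustrative sketch (not part of this diff): classifying a port-hopping
// address. "example.com:20000-30000" splits into host "example.com" and port
// "20000-30000"; the "-" makes isPortHoppingPort return true, so fillServerAddr
// resolves it with udphop.ResolveUDPHopAddr instead of net.ResolveUDPAddr.
func exampleClassifyAddr() (string, bool) {
	host, port, _ := parseServerAddrString("example.com:20000-30000")
	return host, isPortHoppingPort(port)
}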
// normalizeCertHash normalizes a certificate hash string.
// It converts all characters to lowercase and removes possible separators such as ":" and "-".
func normalizeCertHash(hash string) string {
r := strings.ToLower(hash)
r = strings.ReplaceAll(r, ":", "")
r = strings.ReplaceAll(r, "-", "")
return r
}
type adaptiveConnFactory struct {
NewFunc func(addr net.Addr) (net.PacketConn, error)
Obfuscator obfs.Obfuscator // nil if no obfuscation
}
func (f *adaptiveConnFactory) New(addr net.Addr) (net.PacketConn, error) {
if f.Obfuscator == nil {
return f.NewFunc(addr)
} else {
conn, err := f.NewFunc(addr)
if err != nil {
return nil, err
}
return obfs.WrapPacketConn(conn, f.Obfuscator), nil
}
}


@ -0,0 +1,18 @@
package cmd
import (
"fmt"
)
type configError struct {
Field string
Err error
}
func (e configError) Error() string {
return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err)
}
func (e configError) Unwrap() error {
return e.Err
}


@ -0,0 +1,68 @@
package utils
import (
"errors"
"fmt"
"strconv"
"strings"
)
const (
Byte = 1
Kilobyte = Byte * 1000
Megabyte = Kilobyte * 1000
Gigabyte = Megabyte * 1000
Terabyte = Gigabyte * 1000
)
// StringToBps converts a string to a bandwidth value in bytes per second.
// E.g. "100 Mbps", "512 kbps", "1g" are all valid.
func StringToBps(s string) (uint64, error) {
s = strings.ToLower(strings.TrimSpace(s))
spl := 0
for i, c := range s {
if c < '0' || c > '9' {
spl = i
break
}
}
if spl == 0 {
// No unit or no value
return 0, errors.New("invalid format")
}
v, err := strconv.ParseUint(s[:spl], 10, 64)
if err != nil {
return 0, err
}
unit := strings.TrimSpace(s[spl:])
switch strings.ToLower(unit) {
case "b", "bps":
return v * Byte / 8, nil
case "k", "kb", "kbps":
return v * Kilobyte / 8, nil
case "m", "mb", "mbps":
return v * Megabyte / 8, nil
case "g", "gb", "gbps":
return v * Gigabyte / 8, nil
case "t", "tb", "tbps":
return v * Terabyte / 8, nil
default:
return 0, errors.New("unsupported unit")
}
}
// ConvBandwidth handles both string and int types for bandwidth.
// When using string, it will be parsed as a bandwidth string with units.
// When using int, it will be parsed as a raw bandwidth in bytes per second.
// It does NOT support float types.
func ConvBandwidth(bw interface{}) (uint64, error) {
switch bwT := bw.(type) {
case string:
return StringToBps(bwT)
case int:
return uint64(bwT), nil
default:
return 0, fmt.Errorf("invalid type %T for bandwidth", bwT)
}
}
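// Illustrative sketch (not part of this diff): "100 mbps" parses to
// 100 * 1000 * 1000 / 8 = 12,500,000 bytes per second, while an int such as
// 12500000 passes through ConvBandwidth unchanged.
func exampleBandwidth() (uint64, uint64) {
	a, _ := ConvBandwidth("100 mbps") // 12500000
	b, _ := ConvBandwidth(12500000)   // 12500000
	return a, b
}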


@ -0,0 +1,316 @@
package client
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"time"
coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/utils"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/http3"
)
const (
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
)
type Client interface {
TCP(addr string) (net.Conn, error)
UDP() (HyUDPConn, error)
Close() error
}
type HyUDPConn interface {
Receive() ([]byte, string, error)
Send([]byte, string) error
net.PacketConn
}
type HandshakeInfo struct {
UDPEnabled bool
Tx uint64 // 0 if using BBR
}
func NewClient(config *Config) (Client, *HandshakeInfo, error) {
if err := config.verifyAndFill(); err != nil {
return nil, nil, err
}
c := &clientImpl{
config: config,
}
info, err := c.connect()
if err != nil {
return nil, nil, err
}
return c, info, nil
}
type clientImpl struct {
config *Config
pktConn net.PacketConn
conn quic.Connection
udpSM *udpSessionManager
}
func (c *clientImpl) connect() (*HandshakeInfo, error) {
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
if err != nil {
return nil, err
}
// Convert config to TLS config & QUIC config
tlsConfig := &tls.Config{
ServerName: c.config.TLSConfig.ServerName,
InsecureSkipVerify: c.config.TLSConfig.InsecureSkipVerify,
VerifyPeerCertificate: c.config.TLSConfig.VerifyPeerCertificate,
RootCAs: c.config.TLSConfig.RootCAs,
}
quicConfig := &quic.Config{
InitialStreamReceiveWindow: c.config.QUICConfig.InitialStreamReceiveWindow,
MaxStreamReceiveWindow: c.config.QUICConfig.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: c.config.QUICConfig.InitialConnectionReceiveWindow,
MaxConnectionReceiveWindow: c.config.QUICConfig.MaxConnectionReceiveWindow,
MaxIdleTimeout: c.config.QUICConfig.MaxIdleTimeout,
KeepAlivePeriod: c.config.QUICConfig.KeepAlivePeriod,
DisablePathMTUDiscovery: c.config.QUICConfig.DisablePathMTUDiscovery,
EnableDatagrams: true,
}
// Prepare RoundTripper
var conn quic.EarlyConnection
rt := &http3.RoundTripper{
EnableDatagrams: true,
TLSClientConfig: tlsConfig,
QuicConfig: quicConfig,
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
qc, err := quic.DialEarly(ctx, pktConn, c.config.ServerAddr, tlsCfg, cfg)
if err != nil {
return nil, err
}
conn = qc
return qc, nil
},
}
// Send auth HTTP request
req := &http.Request{
Method: http.MethodPost,
URL: &url.URL{
Scheme: "https",
Host: protocol.URLHost,
Path: protocol.URLPath,
},
Header: make(http.Header),
}
protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{
Auth: c.config.Auth,
Rx: c.config.BandwidthConfig.MaxRx,
})
resp, err := rt.RoundTrip(req)
if err != nil {
if conn != nil {
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
}
_ = pktConn.Close()
return nil, coreErrs.ConnectError{Err: err}
}
if resp.StatusCode != protocol.StatusAuthOK {
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
_ = pktConn.Close()
return nil, coreErrs.AuthError{StatusCode: resp.StatusCode}
}
// Auth OK
authResp := protocol.AuthResponseFromHeader(resp.Header)
var actualTx uint64
if authResp.RxAuto {
// Server asks client to use bandwidth detection,
// ignore local bandwidth config and use BBR
congestion.UseBBR(conn)
} else {
// actualTx = min(serverRx, clientTx)
actualTx = authResp.Rx
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
// Server doesn't have a limit, or our clientTx is smaller than serverRx
actualTx = c.config.BandwidthConfig.MaxTx
}
if actualTx > 0 {
congestion.UseBrutal(conn, actualTx)
} else {
// We don't know our own bandwidth either, use BBR
congestion.UseBBR(conn)
}
}
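// Worked example (not part of this diff): if the server advertises
// Rx = 50_000_000 bytes/s and the local bandwidth.up resolved to
// MaxTx = 30_000_000, actualTx becomes 30_000_000 and Brutal is used;
// if neither side declares a limit (both values are 0), the client
// falls back to BBR.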
_ = resp.Body.Close()
c.pktConn = pktConn
c.conn = conn
if authResp.UDPEnabled {
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
}
return &HandshakeInfo{
UDPEnabled: authResp.UDPEnabled,
Tx: actualTx,
}, nil
}
// openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Stream, error) {
stream, err := c.conn.OpenStream()
if err != nil {
return nil, err
}
return &utils.QStream{Stream: stream}, nil
}
func (c *clientImpl) TCP(addr string) (net.Conn, error) {
stream, err := c.openStream()
if err != nil {
return nil, wrapIfConnectionClosed(err)
}
// Send request
err = protocol.WriteTCPRequest(stream, addr)
if err != nil {
_ = stream.Close()
return nil, wrapIfConnectionClosed(err)
}
if c.config.FastOpen {
// Don't wait for the response when fast open is enabled.
// Return the connection immediately, defer the response handling
// to the first Read() call.
return &tcpConn{
Orig: stream,
PseudoLocalAddr: c.conn.LocalAddr(),
PseudoRemoteAddr: c.conn.RemoteAddr(),
Established: false,
}, nil
}
// Read response
ok, msg, err := protocol.ReadTCPResponse(stream)
if err != nil {
_ = stream.Close()
return nil, wrapIfConnectionClosed(err)
}
if !ok {
_ = stream.Close()
return nil, coreErrs.DialError{Message: msg}
}
return &tcpConn{
Orig: stream,
PseudoLocalAddr: c.conn.LocalAddr(),
PseudoRemoteAddr: c.conn.RemoteAddr(),
Established: true,
}, nil
}
func (c *clientImpl) UDP() (HyUDPConn, error) {
if c.udpSM == nil {
return nil, coreErrs.DialError{Message: "UDP not enabled"}
}
return c.udpSM.NewUDP()
}
func (c *clientImpl) Close() error {
_ = c.conn.CloseWithError(closeErrCodeOK, "")
_ = c.pktConn.Close()
return nil
}
// wrapIfConnectionClosed checks if the error returned by quic-go
// indicates that the QUIC connection has been permanently closed,
// and if so, wraps the error with coreErrs.ClosedError.
// PITFALL: sometimes quic-go has "internal errors" that are not net.Error,
// but we still need to treat them as ClosedError.
func wrapIfConnectionClosed(err error) error {
netErr, ok := err.(net.Error)
if !ok || !netErr.Temporary() {
return coreErrs.ClosedError{Err: err}
} else {
return err
}
}
type tcpConn struct {
Orig quic.Stream
PseudoLocalAddr net.Addr
PseudoRemoteAddr net.Addr
Established bool
}
func (c *tcpConn) Read(b []byte) (n int, err error) {
if !c.Established {
// Read response
ok, msg, err := protocol.ReadTCPResponse(c.Orig)
if err != nil {
return 0, err
}
if !ok {
return 0, coreErrs.DialError{Message: msg}
}
c.Established = true
}
return c.Orig.Read(b)
}
func (c *tcpConn) Write(b []byte) (n int, err error) {
return c.Orig.Write(b)
}
func (c *tcpConn) Close() error {
return c.Orig.Close()
}
func (c *tcpConn) LocalAddr() net.Addr {
return c.PseudoLocalAddr
}
func (c *tcpConn) RemoteAddr() net.Addr {
return c.PseudoRemoteAddr
}
func (c *tcpConn) SetDeadline(t time.Time) error {
return c.Orig.SetDeadline(t)
}
func (c *tcpConn) SetReadDeadline(t time.Time) error {
return c.Orig.SetReadDeadline(t)
}
func (c *tcpConn) SetWriteDeadline(t time.Time) error {
return c.Orig.SetWriteDeadline(t)
}
type udpIOImpl struct {
Conn quic.Connection
}
func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
for {
msg, err := io.Conn.ReceiveDatagram(context.Background())
if err != nil {
// Connection error, this will stop the session manager
return nil, err
}
udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil {
// Invalid message, this is fine - just wait for the next
continue
}
return udpMsg, nil
}
}
func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
msgN := msg.Serialize(buf)
if msgN < 0 {
// Message larger than buffer, silent drop
return nil
}
return io.Conn.SendDatagram(buf[:msgN])
}


@ -0,0 +1,112 @@
package client
import (
"crypto/x509"
"net"
"time"
"github.com/metacubex/mihomo/transport/hysteria2/core/errors"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/pmtud"
)
const (
defaultStreamReceiveWindow = 8388608 // 8MB
defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB
defaultMaxIdleTimeout = 30 * time.Second
defaultKeepAlivePeriod = 10 * time.Second
)
type Config struct {
ConnFactory ConnFactory
ServerAddr net.Addr
Auth string
TLSConfig TLSConfig
QUICConfig QUICConfig
BandwidthConfig BandwidthConfig
FastOpen bool
filled bool // whether the fields have been verified and filled
}
// verifyAndFill fills the fields that are not set by the user with default values when possible,
// and returns an error if the user has not set a required field or has set an invalid value.
func (c *Config) verifyAndFill() error {
if c.filled {
return nil
}
if c.ConnFactory == nil {
c.ConnFactory = &udpConnFactory{}
}
if c.ServerAddr == nil {
return errors.ConfigError{Field: "ServerAddr", Reason: "must be set"}
}
if c.QUICConfig.InitialStreamReceiveWindow == 0 {
c.QUICConfig.InitialStreamReceiveWindow = defaultStreamReceiveWindow
} else if c.QUICConfig.InitialStreamReceiveWindow < 16384 {
return errors.ConfigError{Field: "QUICConfig.InitialStreamReceiveWindow", Reason: "must be at least 16384"}
}
if c.QUICConfig.MaxStreamReceiveWindow == 0 {
c.QUICConfig.MaxStreamReceiveWindow = defaultStreamReceiveWindow
} else if c.QUICConfig.MaxStreamReceiveWindow < 16384 {
return errors.ConfigError{Field: "QUICConfig.MaxStreamReceiveWindow", Reason: "must be at least 16384"}
}
if c.QUICConfig.InitialConnectionReceiveWindow == 0 {
c.QUICConfig.InitialConnectionReceiveWindow = defaultConnReceiveWindow
} else if c.QUICConfig.InitialConnectionReceiveWindow < 16384 {
return errors.ConfigError{Field: "QUICConfig.InitialConnectionReceiveWindow", Reason: "must be at least 16384"}
}
if c.QUICConfig.MaxConnectionReceiveWindow == 0 {
c.QUICConfig.MaxConnectionReceiveWindow = defaultConnReceiveWindow
} else if c.QUICConfig.MaxConnectionReceiveWindow < 16384 {
return errors.ConfigError{Field: "QUICConfig.MaxConnectionReceiveWindow", Reason: "must be at least 16384"}
}
if c.QUICConfig.MaxIdleTimeout == 0 {
c.QUICConfig.MaxIdleTimeout = defaultMaxIdleTimeout
} else if c.QUICConfig.MaxIdleTimeout < 4*time.Second || c.QUICConfig.MaxIdleTimeout > 120*time.Second {
return errors.ConfigError{Field: "QUICConfig.MaxIdleTimeout", Reason: "must be between 4s and 120s"}
}
if c.QUICConfig.KeepAlivePeriod == 0 {
c.QUICConfig.KeepAlivePeriod = defaultKeepAlivePeriod
} else if c.QUICConfig.KeepAlivePeriod < 2*time.Second || c.QUICConfig.KeepAlivePeriod > 60*time.Second {
return errors.ConfigError{Field: "QUICConfig.KeepAlivePeriod", Reason: "must be between 2s and 60s"}
}
c.QUICConfig.DisablePathMTUDiscovery = c.QUICConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery
c.filled = true
return nil
}
type ConnFactory interface {
New(net.Addr) (net.PacketConn, error)
}
type udpConnFactory struct{}
func (f *udpConnFactory) New(addr net.Addr) (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
}
// TLSConfig contains the TLS configuration fields that we want to expose to the user.
type TLSConfig struct {
ServerName string
InsecureSkipVerify bool
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
RootCAs *x509.CertPool
}
// QUICConfig contains the QUIC configuration fields that we want to expose to the user.
type QUICConfig struct {
InitialStreamReceiveWindow uint64
MaxStreamReceiveWindow uint64
InitialConnectionReceiveWindow uint64
MaxConnectionReceiveWindow uint64
MaxIdleTimeout time.Duration
KeepAlivePeriod time.Duration
DisablePathMTUDiscovery bool // The server may still override this to true on unsupported platforms.
}
// BandwidthConfig describes the maximum bandwidth that the server can use, in bytes per second.
type BandwidthConfig struct {
MaxTx uint64
MaxRx uint64
}


@ -0,0 +1,117 @@
package client
import (
"net"
"sync"
coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors"
)
// reconnectableClientImpl is a wrapper of Client, which can reconnect when the connection is closed,
// except when the caller explicitly calls Close() to permanently close this client.
type reconnectableClientImpl struct {
configFunc func() (*Config, error) // called before connecting
connectedFunc func(Client, *HandshakeInfo, int) // called when successfully connected
client Client
count int
m sync.Mutex
closed bool // permanent close
}
// NewReconnectableClient creates a reconnectable client.
// If lazy is true, the client will not connect until the first call to TCP() or UDP().
// We use a function for config mainly to delay config evaluation
// (which involves DNS resolution) until the actual connection attempt.
func NewReconnectableClient(configFunc func() (*Config, error), connectedFunc func(Client, *HandshakeInfo, int), lazy bool) (Client, error) {
rc := &reconnectableClientImpl{
configFunc: configFunc,
connectedFunc: connectedFunc,
}
if !lazy {
if err := rc.reconnect(); err != nil {
return nil, err
}
}
return rc, nil
}
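// Illustrative sketch (not part of this diff): a lazy reconnectable client whose
// config is rebuilt on every (re)connect attempt; the config values are omitted
// and would come from the caller.
func exampleLazyClient() (Client, error) {
	return NewReconnectableClient(
		func() (*Config, error) {
			// Deferred until the first TCP()/UDP() call (and every reconnect),
			// so DNS resolution happens per connection attempt.
			return &Config{ /* ServerAddr, Auth, TLSConfig, ... */ }, nil
		},
		func(c Client, info *HandshakeInfo, count int) {
			// Invoked after each successful connect; count starts at 1.
		},
		true, // lazy
	)
}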
func (rc *reconnectableClientImpl) reconnect() error {
if rc.client != nil {
_ = rc.client.Close()
}
var info *HandshakeInfo
config, err := rc.configFunc()
if err != nil {
return err
}
rc.client, info, err = NewClient(config)
if err != nil {
return err
} else {
rc.count++
if rc.connectedFunc != nil {
rc.connectedFunc(rc, info, rc.count)
}
return nil
}
}
func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) {
rc.m.Lock()
defer rc.m.Unlock()
if rc.closed {
return nil, coreErrs.ClosedError{}
}
if rc.client == nil {
// No active connection, connect first
if err := rc.reconnect(); err != nil {
return nil, err
}
}
conn, err := rc.client.TCP(addr)
if _, ok := err.(coreErrs.ClosedError); ok {
// Connection closed, reconnect
if err := rc.reconnect(); err != nil {
return nil, err
}
return rc.client.TCP(addr)
} else {
// OK or some other temporary error
return conn, err
}
}
func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) {
rc.m.Lock()
defer rc.m.Unlock()
if rc.closed {
return nil, coreErrs.ClosedError{}
}
if rc.client == nil {
// No active connection, connect first
if err := rc.reconnect(); err != nil {
return nil, err
}
}
conn, err := rc.client.UDP()
if _, ok := err.(coreErrs.ClosedError); ok {
// Connection closed, reconnect
if err := rc.reconnect(); err != nil {
return nil, err
}
return rc.client.UDP()
} else {
// OK or some other temporary error
return conn, err
}
}
func (rc *reconnectableClientImpl) Close() error {
rc.m.Lock()
defer rc.m.Unlock()
rc.closed = true
if rc.client != nil {
return rc.client.Close()
}
return nil
}


@ -0,0 +1,223 @@
package client
import (
"errors"
"io"
"math/rand"
"net"
"sync"
"time"
coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/frag"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol"
"github.com/metacubex/quic-go"
M "github.com/sagernet/sing/common/metadata"
)
const (
udpMessageChanSize = 1024
)
type udpIO interface {
ReceiveMessage() (*protocol.UDPMessage, error)
SendMessage([]byte, *protocol.UDPMessage) error
}
type udpConn struct {
ID uint32
D *frag.Defragger
ReceiveCh chan *protocol.UDPMessage
SendBuf []byte
SendFunc func([]byte, *protocol.UDPMessage) error
CloseFunc func()
Closed bool
}
func (u *udpConn) Receive() ([]byte, string, error) {
for {
msg := <-u.ReceiveCh
if msg == nil {
// Closed
return nil, "", io.EOF
}
dfMsg := u.D.Feed(msg)
if dfMsg == nil {
// Incomplete message, wait for more
continue
}
return dfMsg.Data, dfMsg.Addr, nil
}
}
func (u *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
bs, addrStr, err := u.Receive()
n = copy(p, bs)
addr = M.ParseSocksaddr(addrStr).UDPAddr()
return
}
// Send is not thread-safe, as it uses a shared SendBuf.
func (u *udpConn) Send(data []byte, addr string) error {
// Try no frag first
msg := &protocol.UDPMessage{
SessionID: u.ID,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: addr,
Data: data,
}
err := u.SendFunc(u.SendBuf, msg)
var errTooLarge quic.ErrMessageTooLarge
if errors.As(err, &errTooLarge) {
// Message too large, try fragmentation
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
for _, fMsg := range fMsgs {
err := u.SendFunc(u.SendBuf, &fMsg)
if err != nil {
return err
}
}
return nil
} else {
return err
}
}
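// For example (not part of this diff): if SendDatagram rejects the first attempt
// because the serialized message exceeds the QUIC datagram size limit, the same
// payload is re-sent as several smaller fragments sharing a randomly chosen
// non-zero PacketID, and the receiver's Defragger (see Receive above)
// reassembles them.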
func (u *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = u.Send(p, M.SocksaddrFromNet(addr).String())
if err != nil {
return
}
n = len(p)
return
}
func (u *udpConn) Close() error {
u.CloseFunc()
return nil
}
func (u *udpConn) LocalAddr() net.Addr {
// a fake implementation to satisfy net.PacketConn
return nil
}
func (u *udpConn) SetDeadline(t time.Time) error {
// a fake implementation to satisfy net.PacketConn
return nil
}
func (u *udpConn) SetReadDeadline(t time.Time) error {
// a fake implementation to satisfy net.PacketConn
return nil
}
func (u *udpConn) SetWriteDeadline(t time.Time) error {
// a fake implementation to satisfy net.PacketConn
return nil
}
type udpSessionManager struct {
io udpIO
mutex sync.RWMutex
m map[uint32]*udpConn
nextID uint32
closed bool
}
func newUDPSessionManager(io udpIO) *udpSessionManager {
m := &udpSessionManager{
io: io,
m: make(map[uint32]*udpConn),
nextID: 1,
}
go m.run()
return m
}
func (m *udpSessionManager) run() error {
defer m.closeCleanup()
for {
msg, err := m.io.ReceiveMessage()
if err != nil {
return err
}
m.feed(msg)
}
}
func (m *udpSessionManager) closeCleanup() {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, conn := range m.m {
m.close(conn)
}
m.closed = true
}
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
m.mutex.RLock()
defer m.mutex.RUnlock()
conn, ok := m.m[msg.SessionID]
if !ok {
// Ignore message from unknown session
return
}
select {
case conn.ReceiveCh <- msg:
// OK
default:
// Channel full, drop the message
}
}
// NewUDP creates a new UDP session.
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closed {
return nil, coreErrs.ClosedError{}
}
id := m.nextID
m.nextID++
conn := &udpConn{
ID: id,
D: &frag.Defragger{},
ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize),
SendBuf: make([]byte, protocol.MaxUDPSize),
SendFunc: m.io.SendMessage,
}
conn.CloseFunc = func() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.close(conn)
}
m.m[id] = conn
return conn, nil
}
func (m *udpSessionManager) close(conn *udpConn) {
if !conn.Closed {
conn.Closed = true
close(conn.ReceiveCh)
delete(m.m, conn.ID)
}
}
func (m *udpSessionManager) Count() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.m)
}


@ -0,0 +1,75 @@
package errors
import (
"fmt"
"strconv"
)
// ConfigError is returned when a configuration field is invalid.
type ConfigError struct {
Field string
Reason string
}
func (c ConfigError) Error() string {
return fmt.Sprintf("invalid config: %s: %s", c.Field, c.Reason)
}
// ConnectError is returned when the client fails to connect to the server.
type ConnectError struct {
Err error
}
func (c ConnectError) Error() string {
return "connect error: " + c.Err.Error()
}
func (c ConnectError) Unwrap() error {
return c.Err
}
// AuthError is returned when the client fails to authenticate with the server.
type AuthError struct {
StatusCode int
}
func (a AuthError) Error() string {
return "authentication error, HTTP status code: " + strconv.Itoa(a.StatusCode)
}
// DialError is returned when the server rejects the client's dial request.
// This applies to both TCP and UDP.
type DialError struct {
Message string
}
func (c DialError) Error() string {
return "dial error: " + c.Message
}
// ClosedError is returned when the client attempts to use a closed connection.
type ClosedError struct {
Err error // Can be nil
}
func (c ClosedError) Error() string {
if c.Err == nil {
return "connection closed"
} else {
return "connection closed: " + c.Err.Error()
}
}
func (c ClosedError) Unwrap() error {
return c.Err
}
// ProtocolError is returned when the server/client runs into an unexpected
// or malformed request/response/message.
type ProtocolError struct {
Message string
}
func (p ProtocolError) Error() string {
return "protocol error: " + p.Message
}


@ -0,0 +1,27 @@
package bbr
import (
"math"
"time"
"github.com/metacubex/quic-go/congestion"
)
const (
infBandwidth = Bandwidth(math.MaxUint64)
)
// Bandwidth of a connection
type Bandwidth uint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}
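// Worked example (not part of this diff): 12,000 bytes acknowledged over 10 ms
// gives BandwidthFromDelta(12000, 10*time.Millisecond)
// = 12000 * 1e9 / 1e7 * 8 = 9,600,000 bits per second (9.6 Mbps).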


@ -0,0 +1,874 @@
package bbr
import (
"math"
"time"
"github.com/metacubex/quic-go/congestion"
)
const (
infRTT = time.Duration(math.MaxInt64)
defaultConnectionStateMapQueueSize = 256
defaultCandidatesBufferSize = 256
)
type roundTripCount uint64
// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
// to the caller when the packet is acked or lost.
type sendTimeState struct {
// Whether the other states in this object are valid.
isValid bool
// Whether the sender is app limited at the time the packet was sent.
// App limited bandwidth sample might be artificially low because the sender
// did not have enough data to send in order to saturate the link.
isAppLimited bool
// Total number of sent bytes at the time the packet was sent.
// Includes the packet itself.
totalBytesSent congestion.ByteCount
// Total number of acked bytes at the time the packet was sent.
totalBytesAcked congestion.ByteCount
// Total number of lost bytes at the time the packet was sent.
totalBytesLost congestion.ByteCount
// Total number of inflight bytes at the time the packet was sent.
// Includes the packet itself.
// It should be equal to |total_bytes_sent| minus the sum of
// |total_bytes_acked|, |total_bytes_lost| and total neutered bytes.
bytesInFlight congestion.ByteCount
}
func newSendTimeState(
isAppLimited bool,
totalBytesSent congestion.ByteCount,
totalBytesAcked congestion.ByteCount,
totalBytesLost congestion.ByteCount,
bytesInFlight congestion.ByteCount,
) *sendTimeState {
return &sendTimeState{
isValid: true,
isAppLimited: isAppLimited,
totalBytesSent: totalBytesSent,
totalBytesAcked: totalBytesAcked,
totalBytesLost: totalBytesLost,
bytesInFlight: bytesInFlight,
}
}
type extraAckedEvent struct {
// The excess bytes acknowledged in the time delta for this event.
extraAcked congestion.ByteCount
// The bytes acknowledged and time delta from the event.
bytesAcked congestion.ByteCount
timeDelta time.Duration
// The round trip of the event.
round roundTripCount
}
func maxExtraAckedEventFunc(a, b extraAckedEvent) int {
if a.extraAcked > b.extraAcked {
return 1
} else if a.extraAcked < b.extraAcked {
return -1
}
return 0
}
// BandwidthSample
type bandwidthSample struct {
// The bandwidth at that particular sample. Zero if no valid bandwidth sample
// is available.
bandwidth Bandwidth
// The RTT measurement at this particular sample. Zero if no RTT sample is
// available. Does not correct for delayed ack time.
rtt time.Duration
// |send_rate| is computed from the current packet being acked('P') and an
// earlier packet that is acked before P was sent.
sendRate Bandwidth
// States captured when the packet was sent.
stateAtSend sendTimeState
}
func newBandwidthSample() *bandwidthSample {
return &bandwidthSample{
sendRate: infBandwidth,
}
}
// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every
// ack event to keep track of the degree of ack aggregation (a.k.a. "ack height").
type maxAckHeightTracker struct {
// Tracks the maximum number of bytes acked faster than the estimated
// bandwidth.
maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount]
// The time this aggregation started and the number of bytes acked during it.
aggregationEpochStartTime time.Time
aggregationEpochBytes congestion.ByteCount
// The last sent packet number before the current aggregation epoch started.
lastSentPacketNumberBeforeEpoch congestion.PacketNumber
// The number of ack aggregation epochs ever started, including the ongoing
// one. Stats only.
numAckAggregationEpochs uint64
ackAggregationBandwidthThreshold float64
startNewAggregationEpochAfterFullRound bool
reduceExtraAckedOnBandwidthIncrease bool
}
func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker {
return &maxAckHeightTracker{
maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc),
lastSentPacketNumberBeforeEpoch: invalidPacketNumber,
ackAggregationBandwidthThreshold: 1.0,
}
}
func (m *maxAckHeightTracker) Get() congestion.ByteCount {
return m.maxAckHeightFilter.GetBest().extraAcked
}
func (m *maxAckHeightTracker) Update(
bandwidthEstimate Bandwidth,
isNewMaxBandwidth bool,
roundTripCount roundTripCount,
lastSentPacketNumber congestion.PacketNumber,
lastAckedPacketNumber congestion.PacketNumber,
ackTime time.Time,
bytesAcked congestion.ByteCount,
) congestion.ByteCount {
forceNewEpoch := false
if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth {
// Save and clear existing entries.
best := m.maxAckHeightFilter.GetBest()
secondBest := m.maxAckHeightFilter.GetSecondBest()
thirdBest := m.maxAckHeightFilter.GetThirdBest()
m.maxAckHeightFilter.Clear()
// Reinsert the heights into the filter after recalculating.
expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta)
if expectedBytesAcked < best.bytesAcked {
best.extraAcked = best.bytesAcked - expectedBytesAcked
m.maxAckHeightFilter.Update(best, best.round)
}
expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta)
if expectedBytesAcked < secondBest.bytesAcked {
secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked
m.maxAckHeightFilter.Update(secondBest, secondBest.round)
}
expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta)
if expectedBytesAcked < thirdBest.bytesAcked {
thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked
m.maxAckHeightFilter.Update(thirdBest, thirdBest.round)
}
}
// If any packet sent after the start of the epoch has been acked, start a new
// epoch.
if m.startNewAggregationEpochAfterFullRound &&
m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber &&
lastAckedPacketNumber != invalidPacketNumber &&
lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch {
forceNewEpoch = true
}
if m.aggregationEpochStartTime.IsZero() || forceNewEpoch {
m.aggregationEpochBytes = bytesAcked
m.aggregationEpochStartTime = ackTime
m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber
m.numAckAggregationEpochs++
return 0
}
// Compute how many bytes are expected to be delivered, assuming max bandwidth
// is correct.
aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime)
expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta)
// Reset the current aggregation epoch as soon as the ack arrival rate is less
// than or equal to the max bandwidth.
if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) {
// Reset to start measuring a new aggregation epoch.
m.aggregationEpochBytes = bytesAcked
m.aggregationEpochStartTime = ackTime
m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber
m.numAckAggregationEpochs++
return 0
}
m.aggregationEpochBytes += bytesAcked
// Compute how many extra bytes were delivered vs max bandwidth.
extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked
newEvent := extraAckedEvent{
extraAcked: expectedBytesAcked,
bytesAcked: m.aggregationEpochBytes,
timeDelta: aggregationDelta,
}
m.maxAckHeightFilter.Update(newEvent, roundTripCount)
return extraBytesAcked
}
func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) {
m.maxAckHeightFilter.SetWindowLength(length)
}
func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) {
newEvent := extraAckedEvent{
extraAcked: newHeight,
round: newTime,
}
m.maxAckHeightFilter.Reset(newEvent, newTime)
}
func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) {
m.ackAggregationBandwidthThreshold = threshold
}
func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) {
m.startNewAggregationEpochAfterFullRound = value
}
func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) {
m.reduceExtraAckedOnBandwidthIncrease = value
}
func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 {
return m.ackAggregationBandwidthThreshold
}
func (m *maxAckHeightTracker) NumAckAggregationEpochs() uint64 {
return m.numAckAggregationEpochs
}
// AckPoint represents a point on the ack line.
type ackPoint struct {
ackTime time.Time
totalBytesAcked congestion.ByteCount
}
// RecentAckPoints maintains the most recent 2 ack points at distinct times.
type recentAckPoints struct {
ackPoints [2]ackPoint
}
func (r *recentAckPoints) Update(ackTime time.Time, totalBytesAcked congestion.ByteCount) {
if ackTime.Before(r.ackPoints[1].ackTime) {
r.ackPoints[1].ackTime = ackTime
} else if ackTime.After(r.ackPoints[1].ackTime) {
r.ackPoints[0] = r.ackPoints[1]
r.ackPoints[1].ackTime = ackTime
}
r.ackPoints[1].totalBytesAcked = totalBytesAcked
}
func (r *recentAckPoints) Clear() {
r.ackPoints[0] = ackPoint{}
r.ackPoints[1] = ackPoint{}
}
func (r *recentAckPoints) MostRecentPoint() *ackPoint {
return &r.ackPoints[1]
}
func (r *recentAckPoints) LessRecentPoint() *ackPoint {
if r.ackPoints[0].totalBytesAcked != 0 {
return &r.ackPoints[0]
}
return &r.ackPoints[1]
}
// ConnectionStateOnSentPacket represents the information about a sent packet
// and the state of the connection at the moment the packet was sent,
// specifically the information about the most recently acknowledged packet at
// that moment.
type connectionStateOnSentPacket struct {
// Time at which the packet is sent.
sentTime time.Time
// Size of the packet.
size congestion.ByteCount
// The value of |totalBytesSentAtLastAckedPacket| at the time the
// packet was sent.
totalBytesSentAtLastAckedPacket congestion.ByteCount
// The value of |lastAckedPacketSentTime| at the time the packet was
// sent.
lastAckedPacketSentTime time.Time
// The value of |lastAckedPacketAckTime| at the time the packet was
// sent.
lastAckedPacketAckTime time.Time
// Send time states that are returned to the congestion controller when the
// packet is acked or lost.
sendTimeState sendTimeState
}
// newConnectionStateOnSentPacket records a snapshot of the bandwidth sampler's
// state at the moment the packet is sent.
// |bytesInFlight| is the bytes in flight right after the packet is sent.
func newConnectionStateOnSentPacket(
sentTime time.Time,
size congestion.ByteCount,
bytesInFlight congestion.ByteCount,
sampler *bandwidthSampler,
) *connectionStateOnSentPacket {
return &connectionStateOnSentPacket{
sentTime: sentTime,
size: size,
totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
lastAckedPacketSentTime: sampler.lastAckedPacketSentTime,
lastAckedPacketAckTime: sampler.lastAckedPacketAckTime,
sendTimeState: *newSendTimeState(
sampler.isAppLimited,
sampler.totalBytesSent,
sampler.totalBytesAcked,
sampler.totalBytesLost,
bytesInFlight,
),
}
}
// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
// bandwidth sample for every packet acknowledged. The samples are taken for
// individual packets, and are not filtered; the consumer has to filter the
// bandwidth samples itself. In certain cases, the sampler will locally severely
// underestimate the bandwidth, hence a maximum filter with a size of at least
// one RTT is recommended.
//
// This class bases its samples on the slope of two curves: the number of bytes
// sent over time, and the number of bytes acknowledged as received over time.
// It produces a sample of both slopes for every packet that gets acknowledged,
// based on a slope between two points on each of the corresponding curves. Note
// that due to the packet loss, the number of bytes on each curve might get
// further and further away from each other, meaning that it is not feasible to
// compare byte values coming from different curves with each other.
//
// The obvious points for measuring slope sample are the ones corresponding to
// the packet that was just acknowledged. Let us denote them as S_1 (point at
// which the current packet was sent) and A_1 (point at which the current packet
// was acknowledged). However, taking a slope requires two points on each line,
// so estimating bandwidth requires picking a packet in the past with respect to
// which the slope is measured.
//
// For that purpose, BandwidthSampler always keeps track of the most recently
// acknowledged packet, and records it together with every outgoing packet.
// When a packet gets acknowledged (A_1), it has not only information about when
// it itself was sent (S_1), but also the information about the latest
// acknowledged packet right before it was sent (S_0 and A_0).
//
// Based on that data, send and ack rate are estimated as:
//
// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
//
// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
// However, in certain cases (e.g. ack compression) the ack rate at a point may
// end up higher than the rate at which the data was originally sent, which is
// not indicative of the real bandwidth. Hence, we use the send rate as an upper
// bound, and the sample value is
//
// rate_sample = min(send_rate, ack_rate)
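//
// For illustration, with assumed numbers: if S_0 = (t=0s, 100 KB sent),
// S_1 = (t=1s, 300 KB sent), A_0 = (t=0.5s, 80 KB acked) and
// A_1 = (t=1.5s, 240 KB acked), then send_rate = 200 KB/s,
// ack_rate = 160 KB/s, and the sample is min(200, 160) = 160 KB/s.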
//
// An important edge case handled by the sampler is tracking the app-limited
// samples. There are multiple meanings of "app-limited" used interchangeably,
// hence it is important to understand and to be able to distinguish between
// them.
//
// Meaning 1: connection state. The connection is said to be app-limited when
// there is no outstanding data to send. This means that certain bandwidth
// samples in the future would not be an accurate indication of the link
// capacity, and it is important to inform consumer about that. Whenever
// connection becomes app-limited, the sampler is notified via OnAppLimited()
// method.
//
// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
// sampler becomes notified about the connection being app-limited, it enters
// app-limited phase. In that phase, all *sent* packets are marked as
// app-limited. Note that the connection itself does not have to be
// app-limited during the app-limited phase, and in fact it will not be
// (otherwise how would it send packets?). The boolean flag below indicates
// whether the sampler is in that phase.
//
// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
// sent during the app-limited phase, the resulting sample related to the
// packet will be marked as app-limited.
//
// With the terminology issue out of the way, let us consider the question of
// what kind of situation it addresses.
//
// Consider a scenario where we first send packets 1 to 20 at a regular
// bandwidth, and then immediately run out of data. After a few seconds, we send
// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
// we use to compute the slope is going to be packet 20, a few seconds apart
// from the current packet, hence the resulting estimate would be extremely low
// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
// meaning that the bandwidth sample would exclude the quiescence.
//
// Based on the analysis of that scenario, we implement the following rule: once
// OnAppLimited() is called, all sent packets will produce app-limited samples
// up until an ack for a packet that was sent after OnAppLimited() was called.
// Note that while the scenario above is not the only scenario when the
// connection is app-limited, the approach works in other cases too.
type congestionEventSample struct {
// The maximum bandwidth sample from all acked packets.
// QuicBandwidth::Zero() if no samples are available.
sampleMaxBandwidth Bandwidth
// Whether |sample_max_bandwidth| is from an app-limited sample.
sampleIsAppLimited bool
// The minimum rtt sample from all acked packets.
// QuicTime::Delta::Infinite() if no samples are available.
sampleRtt time.Duration
// For each packet p in acked packets, this is the max value of INFLIGHT(p),
// where INFLIGHT(p) is the number of bytes acked while p is inflight.
sampleMaxInflight congestion.ByteCount
// The send state of the largest packet in acked_packets, unless it is
// empty. If acked_packets is empty, it's the send state of the largest
// packet in lost_packets.
lastPacketSendState sendTimeState
// The number of extra bytes acked from this ack event, compared to what is
// expected from the flow's bandwidth. Larger value means more ack
// aggregation.
extraAcked congestion.ByteCount
}
func newCongestionEventSample() *congestionEventSample {
return &congestionEventSample{
sampleRtt: infRTT,
}
}
type bandwidthSampler struct {
// The total number of congestion controlled bytes sent during the connection.
totalBytesSent congestion.ByteCount
// The total number of congestion controlled bytes which were acknowledged.
totalBytesAcked congestion.ByteCount
// The total number of congestion controlled bytes which were lost.
totalBytesLost congestion.ByteCount
// The total number of congestion controlled bytes which have been neutered.
totalBytesNeutered congestion.ByteCount
// The value of |total_bytes_sent_| at the time the last acknowledged packet
// was sent. Valid only when |last_acked_packet_sent_time_| is valid.
totalBytesSentAtLastAckedPacket congestion.ByteCount
// The time at which the last acknowledged packet was sent. Set to
// QuicTime::Zero() if no valid timestamp is available.
lastAckedPacketSentTime time.Time
// The time at which the most recent packet was acknowledged.
lastAckedPacketAckTime time.Time
// The most recently sent packet.
lastSentPacket congestion.PacketNumber
// The most recently acked packet.
lastAckedPacket congestion.PacketNumber
// Indicates whether the bandwidth sampler is currently in an app-limited
// phase.
isAppLimited bool
// The packet that will be acknowledged after this one will cause the sampler
// to exit the app-limited phase.
endOfAppLimitedPhase congestion.PacketNumber
// Record of the connection state at the point where each packet in flight was
// sent, indexed by the packet number.
connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket]
recentAckPoints recentAckPoints
a0Candidates RingBuffer[ackPoint]
// Maximum number of tracked packets.
maxTrackedPackets congestion.ByteCount
maxAckHeightTracker *maxAckHeightTracker
totalBytesAckedAfterLastAckEvent congestion.ByteCount
// True if connection option 'BSAO' is set.
overestimateAvoidance bool
// True if connection option 'BBRB' is set.
limitMaxAckHeightTrackerBySendRate bool
}
func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler {
b := &bandwidthSampler{
maxAckHeightTracker: newMaxAckHeightTracker(maxAckHeightTrackerWindowLength),
connectionStateMap: newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize),
lastSentPacket: invalidPacketNumber,
lastAckedPacket: invalidPacketNumber,
endOfAppLimitedPhase: invalidPacketNumber,
}
b.a0Candidates.Init(defaultCandidatesBufferSize)
return b
}
func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount {
return b.maxAckHeightTracker.Get()
}
func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 {
return b.maxAckHeightTracker.NumAckAggregationEpochs()
}
func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) {
b.maxAckHeightTracker.SetFilterWindowLength(length)
}
func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) {
b.maxAckHeightTracker.Reset(newHeight, newTime)
}
func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) {
b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value)
}
func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) {
b.limitMaxAckHeightTrackerBySendRate = value
}
func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) {
b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value)
}
func (b *bandwidthSampler) EnableOverestimateAvoidance() {
if b.overestimateAvoidance {
return
}
b.overestimateAvoidance = true
b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0)
}
func (b *bandwidthSampler) IsOverestimateAvoidanceEnabled() bool {
return b.overestimateAvoidance
}
func (b *bandwidthSampler) OnPacketSent(
sentTime time.Time,
packetNumber congestion.PacketNumber,
bytes congestion.ByteCount,
bytesInFlight congestion.ByteCount,
isRetransmittable bool,
) {
b.lastSentPacket = packetNumber
if !isRetransmittable {
return
}
b.totalBytesSent += bytes
// If there are no packets in flight, the time at which the new transmission
// opens can be treated as the A_0 point for the purpose of bandwidth
// sampling. This underestimates bandwidth to some extent, and produces some
// artificially low samples for most packets in flight, but it provides us with
// samples at important points where we would not have them otherwise, most
// importantly at the beginning of the connection.
if bytesInFlight == 0 {
b.lastAckedPacketAckTime = sentTime
if b.overestimateAvoidance {
b.recentAckPoints.Clear()
b.recentAckPoints.Update(sentTime, b.totalBytesAcked)
b.a0Candidates.Clear()
b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint())
}
b.totalBytesSentAtLastAckedPacket = b.totalBytesSent
// In this situation ack compression is not a concern, set send rate to
// effectively infinite.
b.lastAckedPacketSentTime = sentTime
}
b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket(
sentTime,
bytes,
bytesInFlight+bytes,
b,
))
}
func (b *bandwidthSampler) OnCongestionEvent(
ackTime time.Time,
ackedPackets []congestion.AckedPacketInfo,
lostPackets []congestion.LostPacketInfo,
maxBandwidth Bandwidth,
estBandwidthUpperBound Bandwidth,
roundTripCount roundTripCount,
) congestionEventSample {
eventSample := newCongestionEventSample()
var lastLostPacketSendState sendTimeState
for _, p := range lostPackets {
sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost)
if sendState.isValid {
lastLostPacketSendState = sendState
}
}
if len(ackedPackets) == 0 {
// Only populate send state for a loss-only event.
eventSample.lastPacketSendState = lastLostPacketSendState
return *eventSample
}
var lastAckedPacketSendState sendTimeState
var maxSendRate Bandwidth
for _, p := range ackedPackets {
sample := b.onPacketAcknowledged(ackTime, p.PacketNumber)
if !sample.stateAtSend.isValid {
continue
}
lastAckedPacketSendState = sample.stateAtSend
if sample.rtt != 0 {
eventSample.sampleRtt = min(eventSample.sampleRtt, sample.rtt)
}
if sample.bandwidth > eventSample.sampleMaxBandwidth {
eventSample.sampleMaxBandwidth = sample.bandwidth
eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited
}
if sample.sendRate != infBandwidth {
maxSendRate = max(maxSendRate, sample.sendRate)
}
inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked
if inflightSample > eventSample.sampleMaxInflight {
eventSample.sampleMaxInflight = inflightSample
}
}
if !lastLostPacketSendState.isValid {
eventSample.lastPacketSendState = lastAckedPacketSendState
} else if !lastAckedPacketSendState.isValid {
eventSample.lastPacketSendState = lastLostPacketSendState
} else {
// If two packets are inflight and an alarm is armed to lose a packet and it
// wakes up late, then the first of two in flight packets could have been
// acknowledged before the wakeup, which re-evaluates loss detection, and
// could declare the later of the two lost.
if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber {
eventSample.lastPacketSendState = lastLostPacketSendState
} else {
eventSample.lastPacketSendState = lastAckedPacketSendState
}
}
isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth
maxBandwidth = max(maxBandwidth, eventSample.sampleMaxBandwidth)
if b.limitMaxAckHeightTrackerBySendRate {
maxBandwidth = max(maxBandwidth, maxSendRate)
}
eventSample.extraAcked = b.onAckEventEnd(min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount)
return *eventSample
}
func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) {
b.totalBytesLost += bytesLost
if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil {
sentPacketToSendTimeState(sentPacketPointer, &s)
}
return s
}
func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) {
b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) {
b.totalBytesNeutered += sentPacket.size
})
}
func (b *bandwidthSampler) OnAppLimited() {
b.isAppLimited = true
b.endOfAppLimitedPhase = b.lastSentPacket
}
func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) {
// A packet can become obsolete when it is removed from QuicUnackedPacketMap's
// view of inflight before it is acked or marked as lost. For example, when
// QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet,
// the packet is removed from QuicUnackedPacketMap's inflight, but is not
// marked as acked or lost in the BandwidthSampler.
b.connectionStateMap.RemoveUpTo(leastUnacked)
}
func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount {
return b.totalBytesSent
}
func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount {
return b.totalBytesLost
}
func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount {
return b.totalBytesAcked
}
func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount {
return b.totalBytesNeutered
}
func (b *bandwidthSampler) IsAppLimited() bool {
return b.isAppLimited
}
func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber {
return b.endOfAppLimitedPhase
}
func (b *bandwidthSampler) max_ack_height() congestion.ByteCount {
return b.maxAckHeightTracker.Get()
}
func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool {
if b.a0Candidates.Empty() {
return false
}
if b.a0Candidates.Len() == 1 {
*a0 = *b.a0Candidates.Front()
return true
}
for i := 1; i < b.a0Candidates.Len(); i++ {
if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked {
*a0 = *b.a0Candidates.Offset(i - 1)
if i > 1 {
for j := 0; j < i-1; j++ {
b.a0Candidates.PopFront()
}
}
return true
}
}
*a0 = *b.a0Candidates.Back()
for k := 0; k < b.a0Candidates.Len()-1; k++ {
b.a0Candidates.PopFront()
}
return true
}
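// Illustrative sketch (not part of the original sampler), with assumed
// candidate points recorded at 10 KB, 40 KB and 90 KB of total acked bytes:
// for a packet sent after 60 KB had been acked, chooseA0Point picks the 40 KB
// point as A_0 and discards the 10 KB candidate, which can no longer be the
// best choice for any later packet.
func exampleChooseA0Point() congestion.ByteCount {
	s := newBandwidthSampler(10)
	for _, acked := range []congestion.ByteCount{10_000, 40_000, 90_000} {
		s.a0Candidates.PushBack(ackPoint{totalBytesAcked: acked})
	}
	var a0 ackPoint
	s.chooseA0Point(60_000, &a0)
	return a0.totalBytesAcked // 40_000
}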
func (b *bandwidthSampler) onPacketAcknowledged(ackTime time.Time, packetNumber congestion.PacketNumber) bandwidthSample {
sample := newBandwidthSample()
b.lastAckedPacket = packetNumber
sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber)
if sentPacketPointer == nil {
return *sample
}
// OnPacketAcknowledgedInner
b.totalBytesAcked += sentPacketPointer.size
b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent
b.lastAckedPacketSentTime = sentPacketPointer.sentTime
b.lastAckedPacketAckTime = ackTime
if b.overestimateAvoidance {
b.recentAckPoints.Update(ackTime, b.totalBytesAcked)
}
if b.isAppLimited {
// Exit app-limited phase in two cases:
// (1) end_of_app_limited_phase_ is not initialized, i.e., so far all
// packets are sent while there are buffered packets or pending data.
// (2) The current acked packet is after the sent packet marked as the end
// of the app-limited phase.
if b.endOfAppLimitedPhase == invalidPacketNumber ||
packetNumber > b.endOfAppLimitedPhase {
b.isAppLimited = false
}
}
// There might have been no packets acknowledged at the moment when the
// current packet was sent. In that case, there is no bandwidth sample to
// make.
if sentPacketPointer.lastAckedPacketSentTime.IsZero() {
return *sample
}
// Infinite rate indicates that the sampler is supposed to discard the
// current send rate sample and use only the ack rate.
sendRate := infBandwidth
if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) {
sendRate = BandwidthFromDelta(
sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket,
sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime))
}
var a0 ackPoint
if !(b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0)) {
a0.ackTime = sentPacketPointer.lastAckedPacketAckTime
a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked
}
// During the slope calculation, ensure that ack time of the current packet is
// always larger than the time of the previous packet, otherwise division by
// zero or integer underflow can occur.
if ackTime.Sub(a0.ackTime) <= 0 {
return *sample
}
ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime))
sample.bandwidth = min(sendRate, ackRate)
// Note: this sample does not account for delayed acknowledgement time. This
// means that the RTT measurements here can be artificially high, especially
// on low bandwidth connections.
sample.rtt = ackTime.Sub(sentPacketPointer.sentTime)
sample.sendRate = sendRate
sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend)
return *sample
}
func (b *bandwidthSampler) onAckEventEnd(
bandwidthEstimate Bandwidth,
isNewMaxBandwidth bool,
roundTripCount roundTripCount,
) congestion.ByteCount {
newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent
if newlyAckedBytes == 0 {
return 0
}
b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked
extraAcked := b.maxAckHeightTracker.Update(
bandwidthEstimate,
isNewMaxBandwidth,
roundTripCount,
b.lastSentPacket,
b.lastAckedPacket,
b.lastAckedPacketAckTime,
newlyAckedBytes)
// If |extra_acked| is zero, i.e. this ack event marks the start of a new ack
// aggregation epoch, save LessRecentPoint, which is the last ack point of the
// previous epoch, as an A0 candidate.
if b.overestimateAvoidance && extraAcked == 0 {
b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint())
}
return extraAcked
}
func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) {
*sendTimeState = sentPacket.sendTimeState
sendTimeState.isValid = true
}
// bytesFromBandwidthAndTimeDelta calculates the number of bytes transferred
// from a bandwidth (in bits per second) and a time delta.
func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount {
return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) /
(congestion.ByteCount(time.Second) * 8)
}
func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration {
return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth)
}
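// Illustrative sketch (not part of the original file): a quick sanity check of
// the two conversions above, using assumed values. With Bandwidth expressed in
// bits per second, 8,000,000 bit/s over 100ms corresponds to 100,000 bytes,
// and converting those bytes back at the same bandwidth recovers the 100ms.
func exampleBandwidthTimeConversions() (congestion.ByteCount, time.Duration) {
	bw := Bandwidth(8_000_000) // 8 Mbit/s, assumed for the example
	bytes := bytesFromBandwidthAndTimeDelta(bw, 100*time.Millisecond)
	delta := timeDeltaFromBytesAndBandwidth(bytes, bw)
	return bytes, delta // 100,000 bytes and 100ms
}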

View file

@@ -0,0 +1,944 @@
package bbr
import (
"fmt"
"math/rand"
"net"
"time"
"github.com/metacubex/quic-go/congestion"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/common"
)
// BbrSender implements BBR congestion control algorithm. BBR aims to estimate
// the current available Bottleneck Bandwidth and RTT (hence the name), and
// regulates the pacing rate and the size of the congestion window based on
// those signals.
//
// BBR relies on pacing in order to function properly. Do not use BBR when
// pacing is disabled.
//
const (
minBps = 65536 // 64 kbps
invalidPacketNumber = -1
initialCongestionWindowPackets = 32
// Constants based on TCP defaults.
// The minimum CWND to ensure delayed acks don't reduce bandwidth measurements.
// Does not inflate the pacing rate.
defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSizeIPv4)
// The gain used for the STARTUP, equal to 2/ln(2).
defaultHighGain = 2.885
// The newly derived gain for STARTUP, equal to 4 * ln(2)
derivedHighGain = 2.773
// The newly derived CWND gain for STARTUP, 2.
derivedHighCWNDGain = 2.0
)
// The cycle of gains used during the PROBE_BW stage.
var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}
const (
// The length of the gain cycle.
gainCycleLength = len(pacingGain)
// The size of the bandwidth filter window, in round-trips.
bandwidthWindowSize = gainCycleLength + 2
// The time after which the current min_rtt value expires.
minRttExpiry = 10 * time.Second
// The minimum time the connection can spend in PROBE_RTT mode.
probeRttTime = 200 * time.Millisecond
// If the bandwidth does not increase by the factor of |kStartupGrowthTarget|
// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection
// will exit the STARTUP mode.
startupGrowthTarget = 1.25
roundTripsWithoutGrowthBeforeExitingStartup = int64(3)
// Flag.
defaultStartupFullLossCount = 8
quicBbr2DefaultLossThreshold = 0.02
maxBbrBurstPackets = 3
)
type bbrMode int
const (
// Startup phase of the connection.
bbrModeStartup = iota
// After achieving the highest possible bandwidth during the startup, lower
// the pacing rate in order to drain the queue.
bbrModeDrain
// Cruising mode.
bbrModeProbeBw
// Temporarily slow down sending in order to empty the buffer and measure
// the real minimum RTT.
bbrModeProbeRtt
)
// Indicates how the congestion control limits the amount of bytes in flight.
type bbrRecoveryState int
const (
// Do not limit.
bbrRecoveryStateNotInRecovery = iota
// Allow an extra outstanding byte for each byte acknowledged.
bbrRecoveryStateConservation
// Allow two extra outstanding bytes for each byte acknowledged (slow
// start).
bbrRecoveryStateGrowth
)
type bbrSender struct {
rttStats congestion.RTTStatsProvider
clock Clock
pacer *common.Pacer
mode bbrMode
// Bandwidth sampler provides BBR with the bandwidth measurements at
// individual points.
sampler *bandwidthSampler
// The number of the round trips that have occurred during the connection.
roundTripCount roundTripCount
// The packet number of the most recently sent packet.
lastSentPacket congestion.PacketNumber
// Acknowledgement of any packet after |current_round_trip_end_| will cause
// the round trip counter to advance.
currentRoundTripEnd congestion.PacketNumber
// Number of congestion events with some losses, in the current round.
numLossEventsInRound uint64
// Number of total bytes lost in the current round.
bytesLostInRound congestion.ByteCount
// The filter that tracks the maximum bandwidth over the multiple recent
// round-trips.
maxBandwidth *WindowedFilter[Bandwidth, roundTripCount]
// Minimum RTT estimate. Automatically expires within 10 seconds (and
// triggers PROBE_RTT mode) if no new value is sampled during that period.
minRtt time.Duration
// The time at which the current value of |min_rtt_| was assigned.
minRttTimestamp time.Time
// The maximum allowed number of bytes in flight.
congestionWindow congestion.ByteCount
// The initial value of the |congestion_window_|.
initialCongestionWindow congestion.ByteCount
// The largest value the |congestion_window_| can achieve.
maxCongestionWindow congestion.ByteCount
// The smallest value the |congestion_window_| can achieve.
minCongestionWindow congestion.ByteCount
// The pacing gain applied during the STARTUP phase.
highGain float64
// The CWND gain applied during the STARTUP phase.
highCwndGain float64
// The pacing gain applied during the DRAIN phase.
drainGain float64
// The current pacing rate of the connection.
pacingRate Bandwidth
// The gain currently applied to the pacing rate.
pacingGain float64
// The gain currently applied to the congestion window.
congestionWindowGain float64
// The gain used for the congestion window during PROBE_BW. Latched from
// quic_bbr_cwnd_gain flag.
congestionWindowGainConstant float64
// The number of RTTs to stay in STARTUP mode. Defaults to 3.
numStartupRtts int64
// Number of round-trips in PROBE_BW mode, used for determining the current
// pacing gain cycle.
cycleCurrentOffset int
// The time at which the last pacing gain cycle was started.
lastCycleStart time.Time
// Indicates whether the connection has reached the full bandwidth mode.
isAtFullBandwidth bool
// Number of rounds during which there was no significant bandwidth increase.
roundsWithoutBandwidthGain int64
// The bandwidth compared to which the increase is measured.
bandwidthAtLastRound Bandwidth
// Set to true upon exiting quiescence.
exitingQuiescence bool
// Time at which PROBE_RTT has to be exited. Setting it to zero indicates
// that the time is yet unknown as the number of packets in flight has not
// reached the required value.
exitProbeRttAt time.Time
// Indicates whether a round-trip has passed since PROBE_RTT became active.
probeRttRoundPassed bool
// Indicates whether the most recent bandwidth sample was marked as
// app-limited.
lastSampleIsAppLimited bool
// Indicates whether any non app-limited samples have been recorded.
hasNoAppLimitedSample bool
// Current state of recovery.
recoveryState bbrRecoveryState
// Receiving acknowledgement of a packet after |end_recovery_at_| will cause
// BBR to exit the recovery mode. A value above zero indicates at least one
// loss has been detected, so it must not be set back to zero.
endRecoveryAt congestion.PacketNumber
// A window used to limit the number of bytes in flight during loss recovery.
recoveryWindow congestion.ByteCount
// If true, consider all samples in recovery app-limited.
isAppLimitedRecovery bool // not used
// When true, pace at 1.5x and disable packet conservation in STARTUP.
slowerStartup bool // not used
// When true, disables packet conservation in STARTUP.
rateBasedStartup bool // not used
// When true, add the most recent ack aggregation measurement during STARTUP.
enableAckAggregationDuringStartup bool
// When true, expire the windowed ack aggregation values in STARTUP when
// bandwidth increases more than 25%.
expireAckAggregationInStartup bool
// If true, will not exit low gain mode until bytes_in_flight drops below BDP
// or it's time for high gain mode.
drainToTarget bool
// If true, slow down pacing rate in STARTUP when overshooting is detected.
detectOvershooting bool
// Bytes lost while detect_overshooting_ is true.
bytesLostWhileDetectingOvershooting congestion.ByteCount
// Slow down pacing rate if
// bytes_lost_while_detecting_overshooting_ *
// bytes_lost_multiplier_while_detecting_overshooting_ > IW.
bytesLostMultiplierWhileDetectingOvershooting uint8
// When overshooting is detected, do not drop pacing_rate_ below this value /
// min_rtt.
cwndToCalculateMinPacingRate congestion.ByteCount
// Max congestion window when adjusting network parameters.
maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used
// Params.
maxDatagramSize congestion.ByteCount
// Recorded on packet sent; equivalent to |unacked_packets_->bytes_in_flight()|.
bytesInFlight congestion.ByteCount
}
var _ congestion.CongestionControl = &bbrSender{}
func NewBbrSender(
clock Clock,
initialMaxDatagramSize congestion.ByteCount,
) *bbrSender {
return newBbrSender(
clock,
initialMaxDatagramSize,
initialCongestionWindowPackets*initialMaxDatagramSize,
congestion.MaxCongestionWindowPackets*initialMaxDatagramSize,
)
}
func newBbrSender(
clock Clock,
initialMaxDatagramSize,
initialCongestionWindow,
initialMaxCongestionWindow congestion.ByteCount,
) *bbrSender {
b := &bbrSender{
clock: clock,
mode: bbrModeStartup,
sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)),
lastSentPacket: invalidPacketNumber,
currentRoundTripEnd: invalidPacketNumber,
maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]),
congestionWindow: initialCongestionWindow,
initialCongestionWindow: initialCongestionWindow,
maxCongestionWindow: initialMaxCongestionWindow,
minCongestionWindow: defaultMinimumCongestionWindow,
highGain: defaultHighGain,
highCwndGain: defaultHighGain,
drainGain: 1.0 / defaultHighGain,
pacingGain: 1.0,
congestionWindowGain: 1.0,
congestionWindowGainConstant: 2.0,
numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup,
recoveryState: bbrRecoveryStateNotInRecovery,
endRecoveryAt: invalidPacketNumber,
recoveryWindow: initialMaxCongestionWindow,
bytesLostMultiplierWhileDetectingOvershooting: 2,
cwndToCalculateMinPacingRate: initialCongestionWindow,
maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow,
maxDatagramSize: initialMaxDatagramSize,
}
b.pacer = common.NewPacer(b.bandwidthForPacer)
/*
if b.tracer != nil {
b.lastState = logging.CongestionStateStartup
b.tracer.UpdatedCongestionState(logging.CongestionStateStartup)
}
*/
b.enterStartupMode(b.clock.Now())
b.setHighCwndGain(derivedHighCWNDGain)
return b
}
func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
b.rttStats = provider
}
// TimeUntilSend implements the SendAlgorithm interface.
func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
return b.pacer.TimeUntilSend()
}
// HasPacingBudget implements the SendAlgorithm interface.
func (b *bbrSender) HasPacingBudget(now time.Time) bool {
return b.pacer.Budget(now) >= b.maxDatagramSize
}
// OnPacketSent implements the SendAlgorithm interface.
func (b *bbrSender) OnPacketSent(
sentTime time.Time,
bytesInFlight congestion.ByteCount,
packetNumber congestion.PacketNumber,
bytes congestion.ByteCount,
isRetransmittable bool,
) {
b.pacer.SentPacket(sentTime, bytes)
b.lastSentPacket = packetNumber
b.bytesInFlight = bytesInFlight
if bytesInFlight == 0 {
b.exitingQuiescence = true
}
b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable)
}
// CanSend implements the SendAlgorithm interface.
func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool {
return bytesInFlight < b.GetCongestionWindow()
}
// MaybeExitSlowStart implements the SendAlgorithm interface.
func (b *bbrSender) MaybeExitSlowStart() {
// Do nothing
}
// OnPacketAcked implements the SendAlgorithm interface.
func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime time.Time) {
// Do nothing.
}
// OnPacketLost implements the SendAlgorithm interface.
func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
// Do nothing.
}
// OnRetransmissionTimeout implements the SendAlgorithm interface.
func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
// Do nothing.
}
// SetMaxDatagramSize implements the SendAlgorithm interface.
func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) {
if s < b.maxDatagramSize {
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s))
}
cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow
b.maxDatagramSize = s
if cwndIsMinCwnd {
b.congestionWindow = b.minCongestionWindow
}
b.pacer.SetMaxDatagramSize(s)
}
// InSlowStart implements the SendAlgorithmWithDebugInfos interface.
func (b *bbrSender) InSlowStart() bool {
return b.mode == bbrModeStartup
}
// InRecovery implements the SendAlgorithmWithDebugInfos interface.
func (b *bbrSender) InRecovery() bool {
return b.recoveryState != bbrRecoveryStateNotInRecovery
}
// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface.
func (b *bbrSender) GetCongestionWindow() congestion.ByteCount {
if b.mode == bbrModeProbeRtt {
return b.probeRttCongestionWindow()
}
if b.InRecovery() {
return min(b.congestionWindow, b.recoveryWindow)
}
return b.congestionWindow
}
func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
// Do nothing.
}
func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
totalBytesAckedBefore := b.sampler.TotalBytesAcked()
totalBytesLostBefore := b.sampler.TotalBytesLost()
var isRoundStart, minRttExpired bool
var excessAcked, bytesLost congestion.ByteCount
// The send state of the largest packet in acked_packets, unless it is
// empty. If acked_packets is empty, it's the send state of the largest
// packet in lost_packets.
var lastPacketSendState sendTimeState
b.maybeApplimited(priorInFlight)
// Update bytesInFlight
b.bytesInFlight = priorInFlight
for _, p := range ackedPackets {
b.bytesInFlight -= p.BytesAcked
}
for _, p := range lostPackets {
b.bytesInFlight -= p.BytesLost
}
if len(ackedPackets) != 0 {
lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber
isRoundStart = b.updateRoundTripCounter(lastAckedPacket)
b.updateRecoveryState(lastAckedPacket, len(lostPackets) != 0, isRoundStart)
}
sample := b.sampler.OnCongestionEvent(eventTime,
ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount)
if sample.lastPacketSendState.isValid {
b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited
b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited
}
// Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all
// packets in |acked_packets| did not generate valid samples. (e.g. ack of
// ack-only packets). In both cases, sampler_.total_bytes_acked() will not
// change.
if totalBytesAckedBefore != b.sampler.TotalBytesAcked() {
if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() {
b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount)
}
}
if sample.sampleRtt != infRTT {
minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt)
}
bytesLost = b.sampler.TotalBytesLost() - totalBytesLostBefore
excessAcked = sample.extraAcked
lastPacketSendState = sample.lastPacketSendState
if len(lostPackets) != 0 {
b.numLossEventsInRound++
b.bytesLostInRound += bytesLost
}
// Handle logic specific to PROBE_BW mode.
if b.mode == bbrModeProbeBw {
b.updateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) != 0)
}
// Handle logic specific to STARTUP and DRAIN modes.
if isRoundStart && !b.isAtFullBandwidth {
b.checkIfFullBandwidthReached(&lastPacketSendState)
}
b.maybeExitStartupOrDrain(eventTime)
// Handle logic specific to PROBE_RTT.
b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
// Calculate number of packets acked and lost.
bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore
// After the model is updated, recalculate the pacing rate and congestion
// window.
b.calculatePacingRate(bytesLost)
b.calculateCongestionWindow(bytesAcked, excessAcked)
b.calculateRecoveryWindow(bytesAcked, bytesLost)
// Cleanup internal state.
// This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler.
// The "least unacked" should actually be FirstOutstanding, but since we are not passing
// that through OnCongestionEventEx, we will only do an estimate using acked/lost packets
// for now. Because of fast retransmission, they should differ by no more than 2 packets.
// (this is controlled by packetThreshold in quic-go's sentPacketHandler)
var leastUnacked congestion.PacketNumber
if len(ackedPackets) != 0 {
leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2
} else {
leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1
}
b.sampler.RemoveObsoletePackets(leastUnacked)
if isRoundStart {
b.numLossEventsInRound = 0
b.bytesLostInRound = 0
}
}
func (b *bbrSender) PacingRate() Bandwidth {
if b.pacingRate == 0 {
return Bandwidth(b.highGain * float64(
BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt())))
}
return b.pacingRate
}
func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool {
return b.hasNonAppLimitedSample()
}
func (b *bbrSender) hasNonAppLimitedSample() bool {
return b.hasNoAppLimitedSample
}
// Sets the pacing gain used in STARTUP. Must be greater than 1.
func (b *bbrSender) setHighGain(highGain float64) {
b.highGain = highGain
if b.mode == bbrModeStartup {
b.pacingGain = highGain
}
}
// Sets the CWND gain used in STARTUP. Must be greater than 1.
func (b *bbrSender) setHighCwndGain(highCwndGain float64) {
b.highCwndGain = highCwndGain
if b.mode == bbrModeStartup {
b.congestionWindowGain = highCwndGain
}
}
// Sets the gain used in DRAIN. Must be less than 1.
func (b *bbrSender) setDrainGain(drainGain float64) {
b.drainGain = drainGain
}
// What's the current estimated bandwidth in bytes per second.
func (b *bbrSender) bandwidthEstimate() Bandwidth {
return b.maxBandwidth.GetBest()
}
func (b *bbrSender) bandwidthForPacer() congestion.ByteCount {
bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
if bps < minBps {
// We need to make sure that the bandwidth value for the pacer is never zero,
// otherwise it hits an edge case where HasPacingBudget returns false
// but TimeUntilSend is already in the past, causing the quic-go send loop to spin and get stuck.
return minBps
}
return bps
}
// Returns the current estimate of the RTT of the connection. Outside of the
// edge cases, this is minimum RTT.
func (b *bbrSender) getMinRtt() time.Duration {
if b.minRtt != 0 {
return b.minRtt
}
// min_rtt could be available if the handshake packet gets neutered then
// gets acknowledged. This could only happen for QUIC crypto where we do not
// drop keys.
minRtt := b.rttStats.MinRTT()
if minRtt == 0 {
return 100 * time.Millisecond
} else {
return minRtt
}
}
// Computes the target congestion window using the specified gain.
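// For illustration, with assumed numbers: a min RTT of 50ms and a bandwidth
// estimate of 16 Mbit/s (2,000,000 bytes/s) give a BDP of 100,000 bytes, so a
// PROBE_BW congestion window gain of 2.0 targets 200,000 bytes. If no bandwidth
// samples are available yet (a BDP of zero), the gain is applied to the initial
// congestion window instead, and the result is never allowed to fall below the
// minimum congestion window.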
func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount {
bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate())
congestionWindow := congestion.ByteCount(gain * float64(bdp))
// BDP estimate will be zero if no bandwidth samples are available yet.
if congestionWindow == 0 {
congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow))
}
return max(congestionWindow, b.minCongestionWindow)
}
// The target congestion window during PROBE_RTT.
func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount {
return b.minCongestionWindow
}
func (b *bbrSender) maybeUpdateMinRtt(now time.Time, sampleMinRtt time.Duration) bool {
// Do not expire min_rtt if none was ever available.
minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry))
if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 {
b.minRtt = sampleMinRtt
b.minRttTimestamp = now
}
return minRttExpired
}
// Enters the STARTUP mode.
func (b *bbrSender) enterStartupMode(now time.Time) {
b.mode = bbrModeStartup
// b.maybeTraceStateChange(logging.CongestionStateStartup)
b.pacingGain = b.highGain
b.congestionWindowGain = b.highCwndGain
}
// Enters the PROBE_BW mode.
func (b *bbrSender) enterProbeBandwidthMode(now time.Time) {
b.mode = bbrModeProbeBw
// b.maybeTraceStateChange(logging.CongestionStateProbeBw)
b.congestionWindowGain = b.congestionWindowGainConstant
// Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
// excluded because in that case increased gain and decreased gain would not
// follow each other.
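// For illustration (assumed draw): the modulo below yields a value in [0, 6];
// a draw of 0 keeps offset 0 (pacing gain 1.25), while draws 1..6 are shifted
// to offsets 2..7 (pacing gain 1.0). A cycle therefore never begins with the
// 0.75 drain slot at offset 1, which is only ever entered right after the 1.25
// probe at offset 0.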
b.cycleCurrentOffset = int(rand.Int31n(congestion.PacketsPerConnectionID)) % (gainCycleLength - 1)
if b.cycleCurrentOffset >= 1 {
b.cycleCurrentOffset += 1
}
b.lastCycleStart = now
b.pacingGain = pacingGain[b.cycleCurrentOffset]
}
// Updates the round-trip counter if a round-trip has passed. Returns true if
// the counter has been advanced.
func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool {
if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd {
b.roundTripCount++
b.currentRoundTripEnd = b.lastSentPacket
return true
}
return false
}
// Updates the current gain used in PROBE_BW mode.
func (b *bbrSender) updateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) {
// In most cases, the cycle is advanced after an RTT passes.
shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt()))
// If the pacing gain is above 1.0, the connection is trying to probe the
// bandwidth by increasing the number of bytes in flight to at least
// pacing_gain * BDP. Make sure that it actually reaches the target, as long
// as there are no losses suggesting that the buffers are not able to hold
// that much.
if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) {
shouldAdvanceGainCycling = false
}
// If pacing gain is below 1.0, the connection is trying to drain the extra
// queue which could have been incurred by probing prior to it. If the number
// of bytes in flight falls down to the estimated BDP value earlier, conclude
// that the queue has been successfully drained and exit this cycle early.
if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) {
shouldAdvanceGainCycling = true
}
if shouldAdvanceGainCycling {
b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength
b.lastCycleStart = now
// Stay in low gain mode until the target BDP is hit.
// Low gain mode will be exited immediately when the target BDP is achieved.
if b.drainToTarget && b.pacingGain < 1 &&
pacingGain[b.cycleCurrentOffset] == 1 &&
b.bytesInFlight > b.getTargetCongestionWindow(1) {
return
}
b.pacingGain = pacingGain[b.cycleCurrentOffset]
}
}
// Tracks for how many round-trips the bandwidth has not increased
// significantly.
func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) {
if b.lastSampleIsAppLimited {
return
}
target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget)
if b.bandwidthEstimate() >= target {
b.bandwidthAtLastRound = b.bandwidthEstimate()
b.roundsWithoutBandwidthGain = 0
if b.expireAckAggregationInStartup {
// Expire old excess delivery measurements now that bandwidth increased.
b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount)
}
return
}
b.roundsWithoutBandwidthGain++
if b.roundsWithoutBandwidthGain >= b.numStartupRtts ||
b.shouldExitStartupDueToLoss(lastPacketSendState) {
b.isAtFullBandwidth = true
}
}
func (b *bbrSender) maybeApplimited(bytesInFlight congestion.ByteCount) {
congestionWindow := b.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return
}
availableBytes := congestionWindow - bytesInFlight
drainLimited := b.mode == bbrModeDrain && bytesInFlight > congestionWindow/2
if !drainLimited || availableBytes > maxBbrBurstPackets*b.maxDatagramSize {
b.sampler.OnAppLimited()
}
}
// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if
// appropriate.
func (b *bbrSender) maybeExitStartupOrDrain(now time.Time) {
if b.mode == bbrModeStartup && b.isAtFullBandwidth {
b.mode = bbrModeDrain
// b.maybeTraceStateChange(logging.CongestionStateDrain)
b.pacingGain = b.drainGain
b.congestionWindowGain = b.highCwndGain
}
if b.mode == bbrModeDrain && b.bytesInFlight <= b.getTargetCongestionWindow(1) {
b.enterProbeBandwidthMode(now)
}
}
// Decides whether to enter or exit PROBE_RTT.
func (b *bbrSender) maybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) {
if minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt {
b.mode = bbrModeProbeRtt
// b.maybeTraceStateChange(logging.CongestionStateProbRtt)
b.pacingGain = 1.0
// Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight|
// is at the target small value.
b.exitProbeRttAt = time.Time{}
}
if b.mode == bbrModeProbeRtt {
b.sampler.OnAppLimited()
// b.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
if b.exitProbeRttAt.IsZero() {
// If the window has reached the appropriate size, schedule exiting
// PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but
// we allow an extra packet since QUIC checks CWND before sending a
// packet.
if b.bytesInFlight < b.probeRttCongestionWindow()+congestion.MaxPacketBufferSize {
b.exitProbeRttAt = now.Add(probeRttTime)
b.probeRttRoundPassed = false
}
} else {
if isRoundStart {
b.probeRttRoundPassed = true
}
if now.Sub(b.exitProbeRttAt) >= 0 && b.probeRttRoundPassed {
b.minRttTimestamp = now
if !b.isAtFullBandwidth {
b.enterStartupMode(now)
} else {
b.enterProbeBandwidthMode(now)
}
}
}
}
b.exitingQuiescence = false
}
// Determines whether BBR needs to enter, exit or advance state of the
// recovery.
func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) {
// Disable recovery in startup, if loss-based exit is enabled.
if !b.isAtFullBandwidth {
return
}
// Exit recovery when there are no losses for a round.
if hasLosses {
b.endRecoveryAt = b.lastSentPacket
}
switch b.recoveryState {
case bbrRecoveryStateNotInRecovery:
if hasLosses {
b.recoveryState = bbrRecoveryStateConservation
// This will cause the |recovery_window_| to be set to the correct
// value in CalculateRecoveryWindow().
b.recoveryWindow = 0
// Since the conservation phase is meant to be lasting for a whole
// round, extend the current round as if it were started right now.
b.currentRoundTripEnd = b.lastSentPacket
}
case bbrRecoveryStateConservation:
if isRoundStart {
b.recoveryState = bbrRecoveryStateGrowth
}
fallthrough
case bbrRecoveryStateGrowth:
// Exit recovery if appropriate.
if !hasLosses && lastAckedPacket > b.endRecoveryAt {
b.recoveryState = bbrRecoveryStateNotInRecovery
}
}
}
// Determines the appropriate pacing rate for the connection.
func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) {
if b.bandwidthEstimate() == 0 {
return
}
targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate()))
if b.isAtFullBandwidth {
b.pacingRate = targetRate
return
}
// Pace at the rate of initial_window / RTT as soon as RTT measurements are
// available.
if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 {
b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT())
return
}
if b.detectOvershooting {
b.bytesLostWhileDetectingOvershooting += bytesLost
// Check for overshooting with network parameters adjusted when pacing rate
// > target_rate and loss has been detected.
if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 {
if b.hasNoAppLimitedSample ||
b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow {
// We are fairly sure overshoot happens if 1) there is at least one
// non app-limited bw sample or 2) half of IW gets lost. Slow pacing
// rate.
b.pacingRate = max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT()))
b.bytesLostWhileDetectingOvershooting = 0
b.detectOvershooting = false
}
}
}
// Do not decrease the pacing rate during startup.
b.pacingRate = max(b.pacingRate, targetRate)
}
// Determines the appropriate congestion window for the connection.
func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) {
if b.mode == bbrModeProbeRtt {
return
}
targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain)
if b.isAtFullBandwidth {
// Add the max recently measured ack aggregation to CWND.
targetWindow += b.sampler.MaxAckHeight()
} else if b.enableAckAggregationDuringStartup {
// Add the most recent excess acked. Because CWND never decreases in
// STARTUP, this will automatically create a very localized max filter.
targetWindow += excessAcked
}
// Instead of immediately setting the target CWND as the new one, BBR grows
// the CWND towards |target_window| by only increasing it |bytes_acked| at a
// time.
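// For illustration, with assumed numbers: with a target of 300,000 bytes, a
// current window of 250,000 bytes and 20,000 bytes newly acked, a sender at
// full bandwidth moves to min(300,000, 270,000) = 270,000 bytes, while a
// sender still in STARTUP keeps adding the acked bytes as long as it is below
// the target, and never shrinks.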
if b.isAtFullBandwidth {
b.congestionWindow = min(targetWindow, b.congestionWindow+bytesAcked)
} else if b.congestionWindow < targetWindow ||
b.sampler.TotalBytesAcked() < b.initialCongestionWindow {
// If the connection is not yet out of startup phase, do not decrease the
// window.
b.congestionWindow += bytesAcked
}
// Enforce the limits on the congestion window.
b.congestionWindow = max(b.congestionWindow, b.minCongestionWindow)
b.congestionWindow = min(b.congestionWindow, b.maxCongestionWindow)
}
// Determines the appropriate window that constrains the in-flight during recovery.
func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) {
if b.recoveryState == bbrRecoveryStateNotInRecovery {
return
}
// Set up the initial recovery window.
if b.recoveryWindow == 0 {
b.recoveryWindow = b.bytesInFlight + bytesAcked
b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow)
return
}
// Remove losses from the recovery window, while accounting for a potential
// integer underflow.
if b.recoveryWindow >= bytesLost {
b.recoveryWindow = b.recoveryWindow - bytesLost
} else {
b.recoveryWindow = b.maxDatagramSize
}
// In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH,
// release additional |bytes_acked| to achieve a slow-start-like behavior.
if b.recoveryState == bbrRecoveryStateGrowth {
b.recoveryWindow += bytesAcked
}
// Always allow sending at least |bytes_acked| in response.
b.recoveryWindow = max(b.recoveryWindow, b.bytesInFlight+bytesAcked)
b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow)
}
// Return whether we should exit STARTUP due to excessive loss.
func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool {
if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid {
return false
}
inflightAtSend := lastPacketSendState.bytesInFlight
if inflightAtSend > 0 && b.bytesLostInRound > 0 {
if b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold) {
return true
}
return false
}
return false
}
func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount {
return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second)
}
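// Illustrative sketch (not part of the original file): bdpFromRttAndBandwidth
// multiplies the bandwidth (in bits per second) by the RTT and converts the
// result to bytes. With assumed values of 50ms and 16 Mbit/s this is
// 0.05s * 2,000,000 bytes/s = 100,000 bytes.
func exampleBdp() congestion.ByteCount {
	return bdpFromRttAndBandwidth(50*time.Millisecond, Bandwidth(16_000_000)) // 100,000 bytes
}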
func GetInitialPacketSize(addr net.Addr) congestion.ByteCount {
// If this is not a UDP address, we don't know anything about the MTU.
// Use the minimum size of an Initial packet as the max packet size.
if udpAddr, ok := addr.(*net.UDPAddr); ok {
if udpAddr.IP.To4() != nil {
return congestion.InitialPacketSizeIPv4
} else {
return congestion.InitialPacketSizeIPv6
}
} else {
return congestion.MinInitialPacketSize
}
}

View file

@@ -0,0 +1,18 @@
package bbr
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct{}
var _ Clock = DefaultClock{}
// Now gets the current time
func (DefaultClock) Now() time.Time {
return time.Now()
}

View file

@@ -0,0 +1,199 @@
package bbr
import (
"github.com/metacubex/quic-go/congestion"
)
// packetNumberIndexedQueue is a queue of mostly continuous numbered entries
// which supports the following operations:
// - adding elements to the end of the queue, or at some point past the end
// - removing elements in any order
// - retrieving elements
// If all elements are inserted in order, all of the operations above are
// amortized O(1) time.
//
// Internally, the data structure is a deque where each element is marked as
// present or not. The deque starts at the lowest present index. Whenever an
// element is removed, it's marked as not present, and the front of the deque is
// cleared of elements that are not present.
//
// The tail of the queue is not cleared due to the assumption of entries being
// inserted in order, though removing all elements of the queue will return it
// to its initial state.
//
// Note that this data structure is inherently hazardous: adding just two
// entries with a large enough gap between their packet numbers will cause it
// to consume all of the memory available.
// Because of that, it is not a general-purpose container and should not be used
// as one.
type entryWrapper[T any] struct {
present bool
entry T
}
type packetNumberIndexedQueue[T any] struct {
entries RingBuffer[entryWrapper[T]]
numberOfPresentEntries int
firstPacket congestion.PacketNumber
}
func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] {
q := &packetNumberIndexedQueue[T]{
firstPacket: invalidPacketNumber,
}
q.entries.Init(size)
return q
}
// Emplace inserts data associated with |packetNumber| into (or past) the end of the
// queue, filling up the missing intermediate entries as necessary. Returns
// true if the element has been inserted successfully, false if it was already
// in the queue or inserted out of order.
func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool {
if packetNumber == invalidPacketNumber || entry == nil {
return false
}
if p.IsEmpty() {
p.entries.PushBack(entryWrapper[T]{
present: true,
entry: *entry,
})
p.numberOfPresentEntries = 1
p.firstPacket = packetNumber
return true
}
// Do not allow insertion out-of-order.
if packetNumber <= p.LastPacket() {
return false
}
// Handle potentially missing elements.
offset := int(packetNumber - p.FirstPacket())
if gap := offset - p.entries.Len(); gap > 0 {
for i := 0; i < gap; i++ {
p.entries.PushBack(entryWrapper[T]{})
}
}
p.entries.PushBack(entryWrapper[T]{
present: true,
entry: *entry,
})
p.numberOfPresentEntries++
return true
}
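// Illustrative usage sketch (hypothetical values, not part of the original
// queue): emplacing packet 5 into a queue whose last packet is 2 first pushes
// two non-present placeholder wrappers for packets 3 and 4, so lookups by
// packet number stay O(1) while only packets 2 and 5 count as present entries.
func exampleEmplaceWithGap() int {
	q := newPacketNumberIndexedQueue[int](8)
	v := 1
	q.Emplace(2, &v) // first packet, becomes FirstPacket()
	q.Emplace(5, &v) // fills slots 3 and 4 with non-present placeholders
	return q.EntrySlotsUsed() // 4 slots used, 2 present entries
}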
// GetEntry retrieves the entry associated with the packet number. It returns a
// pointer to the entry on success, or nil if the entry does not exist.
func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T {
ew := p.getEntryWrapper(packetNumber)
if ew == nil {
return nil
}
return &ew.entry
}
// Remove removes the entry associated with |packetNumber| from the queue. If the
// entry is present and f is non-nil, f(entry) is called before removing it.
// It returns true if the entry was present.
func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool {
ew := p.getEntryWrapper(packetNumber)
if ew == nil {
return false
}
if f != nil {
f(ew.entry)
}
ew.present = false
p.numberOfPresentEntries--
if packetNumber == p.FirstPacket() {
p.clearup()
}
return true
}
// RemoveUpTo removes all entries up to, but not including, |packetNumber|.
// Unused slots at the front are also removed, which means that when the
// function returns, FirstPacket() can be larger than |packetNumber|.
func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) {
for !p.entries.Empty() &&
p.firstPacket != invalidPacketNumber &&
p.firstPacket < packetNumber {
if p.entries.Front().present {
p.numberOfPresentEntries--
}
p.entries.PopFront()
p.firstPacket++
}
p.clearup()
return
}
// IsEmpty reports whether the queue is empty.
func (p *packetNumberIndexedQueue[T]) IsEmpty() bool {
return p.numberOfPresentEntries == 0
}
// NumberOfPresentEntries returns the number of present entries in the queue.
func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int {
return p.numberOfPresentEntries
}
// EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is
// proportional to the memory usage of the queue.
func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int {
return p.entries.Len()
}
// FirstPacket returns the packet number of the first entry in the queue.
func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) {
return p.firstPacket
}
// LastPacket returns the packet number of the last entry ever inserted into the
// queue. Note that the entry in question may have already been removed. Returns
// invalidPacketNumber if the queue is empty.
func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) {
if p.IsEmpty() {
return invalidPacketNumber
}
return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1)
}
func (p *packetNumberIndexedQueue[T]) clearup() {
for !p.entries.Empty() && !p.entries.Front().present {
p.entries.PopFront()
p.firstPacket++
}
if p.entries.Empty() {
p.firstPacket = invalidPacketNumber
}
}
func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] {
if packetNumber == invalidPacketNumber ||
p.IsEmpty() ||
packetNumber < p.firstPacket {
return nil
}
offset := int(packetNumber - p.firstPacket)
if offset >= p.entries.Len() {
return nil
}
ew := p.entries.Offset(offset)
if ew == nil || !ew.present {
return nil
}
return ew
}
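// A minimal usage sketch (illustrative only; exampleAckState and the packet
// numbers are made up): track per-packet state, then look it up and remove it
// by packet number.
type exampleAckState struct{ bytes int }
func examplePacketQueueUsage() {
	q := newPacketNumberIndexedQueue[exampleAckState](16)
	// Entries must be inserted in increasing packet-number order.
	q.Emplace(1, &exampleAckState{bytes: 1200})
	q.Emplace(2, &exampleAckState{bytes: 1300})
	q.Emplace(5, &exampleAckState{bytes: 800}) // packets 3 and 4 become not-present slots
	if st := q.GetEntry(2); st != nil {
		_ = st.bytes
	}
	// Removing the first packet also trims any leading not-present slots.
	q.Remove(1, func(st exampleAckState) { _ = st.bytes })
	_ = q.FirstPacket() // now 2
	_ = q.LastPacket()  // still 5
}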

View file

@@ -0,0 +1,118 @@
package bbr
// A RingBuffer is a growable ring buffer.
// It acts as a FIFO queue that reuses its backing array to avoid per-element allocations.
type RingBuffer[T any] struct {
ring []T
headPos, tailPos int
full bool
}
// Init preallocates a buffer of the given size.
func (r *RingBuffer[T]) Init(size int) {
r.ring = make([]T, size)
}
// Len returns the number of elements in the ring buffer.
func (r *RingBuffer[T]) Len() int {
if r.full {
return len(r.ring)
}
if r.tailPos >= r.headPos {
return r.tailPos - r.headPos
}
return r.tailPos - r.headPos + len(r.ring)
}
// Empty reports whether the ring buffer is empty.
func (r *RingBuffer[T]) Empty() bool {
return !r.full && r.headPos == r.tailPos
}
// PushBack adds a new element.
// If the ring buffer is full, its capacity is increased first.
func (r *RingBuffer[T]) PushBack(t T) {
if r.full || len(r.ring) == 0 {
r.grow()
}
r.ring[r.tailPos] = t
r.tailPos++
if r.tailPos == len(r.ring) {
r.tailPos = 0
}
if r.tailPos == r.headPos {
r.full = true
}
}
// PopFront removes and returns the front element.
// It must not be called when the buffer is empty; callers should check Empty() first.
func (r *RingBuffer[T]) PopFront() T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue")
}
r.full = false
t := r.ring[r.headPos]
r.ring[r.headPos] = *new(T)
r.headPos++
if r.headPos == len(r.ring) {
r.headPos = 0
}
return t
}
// Offset returns a pointer to the element at the given index from the front.
// It must not be called when the buffer is empty, and index must be smaller
// than Len(); callers should check both first.
func (r *RingBuffer[T]) Offset(index int) *T {
if r.Empty() || index >= r.Len() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index")
}
offset := (r.headPos + index) % len(r.ring)
return &r.ring[offset]
}
// Front returns a pointer to the front element.
// It must not be called when the buffer is empty; callers should check Empty() first.
func (r *RingBuffer[T]) Front() *T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue")
}
return &r.ring[r.headPos]
}
// Back returns a pointer to the back element.
// It must not be called when the buffer is empty; callers should check Empty() first.
func (r *RingBuffer[T]) Back() *T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue")
}
return r.Offset(r.Len() - 1)
}
// grow increases the capacity of the ring buffer, doubling it (or allocating a
// single slot if it was never initialized). It assumes the buffer is full.
func (r *RingBuffer[T]) grow() {
oldRing := r.ring
newSize := len(oldRing) * 2
if newSize == 0 {
newSize = 1
}
r.ring = make([]T, newSize)
headLen := copy(r.ring, oldRing[r.headPos:])
copy(r.ring[headLen:], oldRing[:r.headPos])
r.headPos, r.tailPos, r.full = 0, len(oldRing), false
}
// Clear removes all elements.
func (r *RingBuffer[T]) Clear() {
var zeroValue T
for i := range r.ring {
r.ring[i] = zeroValue
}
r.headPos, r.tailPos, r.full = 0, 0, false
}
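// A minimal usage sketch of the API above: FIFO push/pop with automatic growth
// once the preallocated slots are used up. The values are illustrative.
func exampleRingBufferUsage() {
	var r RingBuffer[int]
	r.Init(2)        // preallocate two slots
	r.PushBack(1)
	r.PushBack(2)
	r.PushBack(3)    // buffer was full, so it grows to four slots first
	_ = *r.Front()   // 1
	_ = *r.Offset(2) // 3
	_ = r.PopFront() // 1; Len() is now 2
	r.Clear()
}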

View file

@@ -0,0 +1,162 @@
package bbr
import (
"golang.org/x/exp/constraints"
)
// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum)
// estimate of a stream of samples over some fixed time interval. (E.g.,
// the minimum RTT over the past five minutes.) The algorithm keeps track of
// the best, second best, and third best min (or max) estimates, maintaining an
// invariant that the measurement time of the n'th best >= n-1'th best.
// The algorithm works as follows. On a reset, all three estimates are set to
// the same sample. The second best estimate is then recorded in the second
// quarter of the window, and a third best estimate is recorded in the second
// half of the window, bounding the worst case error when the true min is
// monotonically increasing (or true max is monotonically decreasing) over the
// window.
//
// A new best sample replaces all three estimates, since the new best is lower
// (or higher) than everything else in the window and it is the most recent.
// The window thus effectively gets reset on every new min. The same property
// holds true for second best and third best estimates. Specifically, when a
// sample arrives that is better than the second best but not better than the
// best, it replaces the second and third best estimates but not the best
// estimate. Similarly, a sample that is better than the third best estimate
// but not the other estimates replaces only the third best estimate.
//
// Finally, when the best expires, it is replaced by the second best, which in
// turn is replaced by the third best. The newest sample replaces the third
// best.
type WindowedFilterValue interface {
any
}
type WindowedFilterTime interface {
constraints.Integer | constraints.Float
}
type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct {
// Time length of window.
windowLength T
estimates []entry[V, T]
comparator func(V, V) int
}
type entry[V WindowedFilterValue, T WindowedFilterTime] struct {
sample V
time T
}
// MaxFilter is a comparator for a max filter: it returns 1 if the first value is
// greater than the second, -1 if it is smaller, and 0 if they are equal.
func MaxFilter[O constraints.Ordered](a, b O) int {
if a > b {
return 1
} else if a < b {
return -1
}
return 0
}
// MinFilter is a comparator for a min filter: it returns 1 if the first value is
// smaller than the second, -1 if it is greater, and 0 if they are equal.
func MinFilter[O constraints.Ordered](a, b O) int {
if a < b {
return 1
} else if a > b {
return -1
}
return 0
}
func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] {
return &WindowedFilter[V, T]{
windowLength: windowLength,
estimates: make([]entry[V, T], 3, 3),
comparator: comparator,
}
}
// Changes the window length. Does not update any current samples.
func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) {
f.windowLength = windowLength
}
func (f *WindowedFilter[V, T]) GetBest() V {
return f.estimates[0].sample
}
func (f *WindowedFilter[V, T]) GetSecondBest() V {
return f.estimates[1].sample
}
func (f *WindowedFilter[V, T]) GetThirdBest() V {
return f.estimates[2].sample
}
// Updates best estimates with |sample|, and expires and updates best
// estimates as necessary.
func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) {
// Reset all estimates if they have not yet been initialized, if new sample
// is a new best, or if the newest recorded estimate is too old.
if f.comparator(f.estimates[0].sample, *new(V)) == 0 ||
f.comparator(newSample, f.estimates[0].sample) >= 0 ||
newTime-f.estimates[2].time > f.windowLength {
f.Reset(newSample, newTime)
return
}
if f.comparator(newSample, f.estimates[1].sample) >= 0 {
f.estimates[1] = entry[V, T]{newSample, newTime}
f.estimates[2] = f.estimates[1]
} else if f.comparator(newSample, f.estimates[2].sample) >= 0 {
f.estimates[2] = entry[V, T]{newSample, newTime}
}
// Expire and update estimates as necessary.
if newTime-f.estimates[0].time > f.windowLength {
// The best estimate hasn't been updated for an entire window, so promote
// second and third best estimates.
f.estimates[0] = f.estimates[1]
f.estimates[1] = f.estimates[2]
f.estimates[2] = entry[V, T]{newSample, newTime}
// Need to iterate one more time. Check if the new best estimate is
// outside the window as well, since it may also have been recorded a
// long time ago. Don't need to iterate once more since we cover that
// case at the beginning of the method.
if newTime-f.estimates[0].time > f.windowLength {
f.estimates[0] = f.estimates[1]
f.estimates[1] = f.estimates[2]
}
return
}
if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 &&
newTime-f.estimates[1].time > f.windowLength/4 {
// A quarter of the window has passed without a better sample, so the
// second-best estimate is taken from the second quarter of the window.
f.estimates[1] = entry[V, T]{newSample, newTime}
f.estimates[2] = f.estimates[1]
return
}
if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 &&
newTime-f.estimates[2].time > f.windowLength/2 {
// We've passed a half of the window without a better estimate, so take
// a third-best estimate from the second half of the window.
f.estimates[2] = entry[V, T]{newSample, newTime}
}
}
// Resets all estimates to new sample.
func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) {
f.estimates[2] = entry[V, T]{newSample, newTime}
f.estimates[1] = f.estimates[2]
f.estimates[0] = f.estimates[1]
}
func (f *WindowedFilter[V, T]) Clear() {
f.estimates = make([]entry[V, T], 3, 3)
}
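// A short usage sketch: track the minimum RTT (in microseconds) seen over a
// 10-second window, with timestamps in milliseconds. All values are illustrative.
func exampleMinRTTFilter() int64 {
	f := NewWindowedFilter[int64, int64](10_000, MinFilter[int64])
	f.Update(30_000, 0)   // first sample initializes all three estimates
	f.Update(25_000, 100) // a new best replaces all three estimates
	f.Update(40_000, 200) // a worse sample leaves the best at 25ms
	return f.GetBest()    // 25_000
}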

View file

@@ -0,0 +1,181 @@
package brutal
import (
"fmt"
"os"
"strconv"
"time"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/common"
"github.com/metacubex/quic-go/congestion"
)
const (
pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample
minSampleCount = 50
minAckRate = 0.8
congestionWindowMultiplier = 2
debugEnv = "HYSTERIA_BRUTAL_DEBUG"
debugPrintInterval = 2
)
var _ congestion.CongestionControl = &BrutalSender{}
type BrutalSender struct {
rttStats congestion.RTTStatsProvider
bps congestion.ByteCount
maxDatagramSize congestion.ByteCount
pacer *common.Pacer
pktInfoSlots [pktInfoSlotCount]pktInfo
ackRate float64
debug bool
lastAckPrintTimestamp int64
}
type pktInfo struct {
Timestamp int64
AckCount uint64
LossCount uint64
}
func NewBrutalSender(bps uint64) *BrutalSender {
debug, _ := strconv.ParseBool(os.Getenv(debugEnv))
bs := &BrutalSender{
bps: congestion.ByteCount(bps),
maxDatagramSize: congestion.InitialPacketSizeIPv4,
ackRate: 1,
debug: debug,
}
bs.pacer = common.NewPacer(func() congestion.ByteCount {
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
})
return bs
}
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
b.rttStats = rttStats
}
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
return b.pacer.TimeUntilSend()
}
func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
return b.pacer.Budget(now) >= b.maxDatagramSize
}
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
return bytesInFlight < b.GetCongestionWindow()
}
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
rtt := b.rttStats.SmoothedRTT()
if rtt <= 0 {
return 10240
}
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
}
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
) {
b.pacer.SentPacket(sentTime, bytes)
}
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, eventTime time.Time,
) {
// Stub
}
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
priorInFlight congestion.ByteCount,
) {
// Stub
}
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
b.maxDatagramSize = size
b.pacer.SetMaxDatagramSize(size)
if b.debug {
b.debugPrint("SetMaxDatagramSize: %d", size)
}
}
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
minTimestamp := currentTimestamp - pktInfoSlotCount
var ackCount, lossCount uint64
for _, info := range b.pktInfoSlots {
if info.Timestamp < minTimestamp {
continue
}
ackCount += info.AckCount
lossCount += info.LossCount
}
if ackCount+lossCount < minSampleCount {
b.ackRate = 1
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
return
}
rate := float64(ackCount) / float64(ackCount+lossCount)
if rate < minAckRate {
b.ackRate = minAckRate
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
return
}
b.ackRate = rate
if b.canPrintAckRate(currentTimestamp) {
b.lastAckPrintTimestamp = currentTimestamp
b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
}
}
func (b *BrutalSender) InSlowStart() bool {
return false
}
func (b *BrutalSender) InRecovery() bool {
return false
}
func (b *BrutalSender) MaybeExitSlowStart() {}
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
}
func (b *BrutalSender) debugPrint(format string, a ...any) {
fmt.Printf("[BrutalSender] [%s] %s\n",
time.Now().Format("15:04:05"),
fmt.Sprintf(format, a...))
}
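// A back-of-the-envelope sketch of the window formula used by
// GetCongestionWindow above: two bandwidth-delay products, inflated by the
// observed loss via ackRate. The parameters and numbers are illustrative.
func exampleBrutalWindow(bps uint64, srtt time.Duration, ackRate float64) uint64 {
	// With bps = 12_500_000 (100 Mbps), srtt = 50ms and ackRate = 0.8,
	// this yields 12_500_000 * 0.05 * 2 / 0.8 = 1_562_500 bytes.
	return uint64(float64(bps) * srtt.Seconds() * congestionWindowMultiplier / ackRate)
}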

View file

@@ -0,0 +1,95 @@
package common
import (
"math"
"time"
"github.com/metacubex/quic-go/congestion"
)
const (
maxBurstPackets = 10
)
// Pacer implements a token bucket pacing algorithm.
type Pacer struct {
budgetAtLastSent congestion.ByteCount
maxDatagramSize congestion.ByteCount
lastSentTime time.Time
getBandwidth func() congestion.ByteCount // in bytes/s
}
func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer {
p := &Pacer{
budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4,
maxDatagramSize: congestion.InitialPacketSizeIPv4,
getBandwidth: getBandwidth,
}
return p
}
func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *Pacer) Budget(now time.Time) congestion.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = congestion.ByteCount(1<<62 - 1)
}
return minByteCount(p.maxBurstSize(), budget)
}
func (p *Pacer) maxBurstSize() congestion.ByteCount {
return maxByteCount(
congestion.ByteCount((congestion.MinPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
maxBurstPackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *Pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(maxDuration(
congestion.MinPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
float64(p.getBandwidth())))*time.Nanosecond,
))
}
func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) {
p.maxDatagramSize = s
}
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return b
}
return a
}
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return a
}
return b
}
func maxDuration(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}
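// A small usage sketch of the token bucket above; the bandwidth and packet size
// are illustrative.
func examplePacerBudget() {
	p := NewPacer(func() congestion.ByteCount {
		return 1_250_000 // 10 Mbps expressed in bytes per second
	})
	start := time.Now()
	p.SentPacket(start, 1200) // spends 1200 bytes of budget
	// One millisecond later roughly 1250 bytes (1_250_000 B/s * 1 ms) have been
	// refilled; the running budget is always capped at maxBurstSize().
	_ = p.Budget(start.Add(time.Millisecond))
	// With at least one full datagram of budget left, TimeUntilSend reports the
	// zero time, i.e. the next packet may be sent immediately.
	_ = p.TimeUntilSend()
}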

View file

@@ -0,0 +1,18 @@
package congestion
import (
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/bbr"
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/brutal"
"github.com/metacubex/quic-go"
)
func UseBBR(conn quic.Connection) {
conn.SetCongestionControl(bbr.NewBbrSender(
bbr.DefaultClock{},
bbr.GetInitialPacketSize(conn.RemoteAddr()),
))
}
func UseBrutal(conn quic.Connection, tx uint64) {
conn.SetCongestionControl(brutal.NewBrutalSender(tx))
}
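// A hedged sketch of how a caller might choose between the two controllers:
// Brutal when an upload bandwidth is known, BBR otherwise. The txBps parameter
// and this selection rule are illustrative assumptions, not part of the API above.
func exampleSelectCongestion(conn quic.Connection, txBps uint64) {
	if txBps > 0 {
		UseBrutal(conn, txBps)
	} else {
		UseBBR(conn)
	}
}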

View file

@ -0,0 +1,77 @@
package frag
import (
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol"
)
func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
if m.Size() <= maxSize {
return []protocol.UDPMessage{*m}
}
fullPayload := m.Data
maxPayloadSize := maxSize - m.HeaderSize()
off := 0
fragID := uint8(0)
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
frags := make([]protocol.UDPMessage, fragCount)
for off < len(fullPayload) {
payloadSize := len(fullPayload) - off
if payloadSize > maxPayloadSize {
payloadSize = maxPayloadSize
}
frag := *m
frag.FragID = fragID
frag.FragCount = fragCount
frag.Data = fullPayload[off : off+payloadSize]
frags[fragID] = frag
off += payloadSize
fragID++
}
return frags
}
// Defragger handles the defragmentation of UDP messages.
// The current implementation can only handle one packet ID at a time.
// If another packet arrives before a packet has received all fragments
// in their entirety, any previous state is discarded.
type Defragger struct {
pktID uint16
frags []*protocol.UDPMessage
count uint8
size int // data size
}
func (d *Defragger) Feed(m *protocol.UDPMessage) *protocol.UDPMessage {
if m.FragCount <= 1 {
return m
}
if m.FragID >= m.FragCount {
		// invalid fragment ID, drop the message
return nil
}
if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) {
// new message, clear previous state
d.pktID = m.PacketID
d.frags = make([]*protocol.UDPMessage, m.FragCount)
d.frags[m.FragID] = m
d.count = 1
d.size = len(m.Data)
} else if d.frags[m.FragID] == nil {
d.frags[m.FragID] = m
d.count++
d.size += len(m.Data)
if int(d.count) == len(d.frags) {
// all fragments received, assemble
data := make([]byte, d.size)
off := 0
for _, frag := range d.frags {
off += copy(data[off:], frag.Data)
}
m.Data = data
m.FragID = 0
m.FragCount = 1
return m
}
}
return nil
}
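// A round-trip sketch: fragment an oversized message with FragUDPMessage and
// feed the pieces back through a Defragger. Sizes and addresses are illustrative.
func exampleFragRoundTrip() *protocol.UDPMessage {
	msg := &protocol.UDPMessage{
		SessionID: 1,
		PacketID:  42,
		FragID:    0,
		FragCount: 1,
		Addr:      "example.com:53",
		Data:      make([]byte, 3000), // larger than a single datagram
	}
	frags := FragUDPMessage(msg, 1200) // each fragment fits in 1200 bytes
	var d Defragger
	var out *protocol.UDPMessage
	for i := range frags {
		out = d.Feed(&frags[i]) // nil until the last fragment arrives
	}
	return out // the reassembled message carrying the original 3000-byte payload
}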

View file

@@ -0,0 +1,7 @@
//go:build linux || windows || darwin
package pmtud
const (
DisablePathMTUDiscovery = false
)

View file

@@ -0,0 +1,13 @@
//go:build !linux && !windows && !darwin
package pmtud
// quic-go's MTU detection is enabled by default on all platforms.
// However, it only actually sets the DF bit on 3 supported platforms (Windows, macOS, Linux).
// As a result, on other platforms, probe packets that should never be fragmented will still
// be fragmented and transmitted. So we're only enabling it for platforms where we've verified
// its functionality for now.
const (
DisablePathMTUDiscovery = true
)

View file

@@ -0,0 +1,68 @@
package protocol
import (
"net/http"
"strconv"
)
const (
URLHost = "hysteria"
URLPath = "/auth"
RequestHeaderAuth = "Hysteria-Auth"
ResponseHeaderUDPEnabled = "Hysteria-UDP"
CommonHeaderCCRX = "Hysteria-CC-RX"
CommonHeaderPadding = "Hysteria-Padding"
StatusAuthOK = 233
)
// AuthRequest is what client sends to server for authentication.
type AuthRequest struct {
Auth string
Rx uint64 // 0 = unknown, client asks server to use bandwidth detection
}
// AuthResponse is what server sends to client when authentication is passed.
type AuthResponse struct {
UDPEnabled bool
Rx uint64 // 0 = unlimited
RxAuto bool // true = server asks client to use bandwidth detection
}
func AuthRequestFromHeader(h http.Header) AuthRequest {
rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
return AuthRequest{
Auth: h.Get(RequestHeaderAuth),
Rx: rx,
}
}
func AuthRequestToHeader(h http.Header, req AuthRequest) {
h.Set(RequestHeaderAuth, req.Auth)
h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10))
h.Set(CommonHeaderPadding, authRequestPadding.String())
}
func AuthResponseFromHeader(h http.Header) AuthResponse {
resp := AuthResponse{}
resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
rxStr := h.Get(CommonHeaderCCRX)
if rxStr == "auto" {
// Special case for server requesting client to use bandwidth detection
resp.RxAuto = true
} else {
resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64)
}
return resp
}
func AuthResponseToHeader(h http.Header, resp AuthResponse) {
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled))
if resp.RxAuto {
h.Set(CommonHeaderCCRX, "auto")
} else {
h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10))
}
h.Set(CommonHeaderPadding, authResponsePadding.String())
}
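// A sketch of the header round trip defined above: the client encodes an
// AuthRequest into HTTP headers and the server decodes it back. The password
// value is illustrative.
func exampleAuthHeaders() AuthRequest {
	h := make(http.Header)
	AuthRequestToHeader(h, AuthRequest{
		Auth: "some-password",
		Rx:   0, // 0 = ask the server to use bandwidth detection
	})
	// On the receiving side the same headers decode back into the struct.
	return AuthRequestFromHeader(h)
}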

View file

@@ -0,0 +1,31 @@
package protocol
import (
"math/rand"
)
const (
paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
// padding specifies a half-open range [Min, Max).
type padding struct {
Min int
Max int
}
func (p padding) String() string {
n := p.Min + rand.Intn(p.Max-p.Min)
bs := make([]byte, n)
for i := range bs {
bs[i] = paddingChars[rand.Intn(len(paddingChars))]
}
return string(bs)
}
var (
authRequestPadding = padding{Min: 256, Max: 2048}
authResponsePadding = padding{Min: 256, Max: 2048}
tcpRequestPadding = padding{Min: 64, Max: 512}
tcpResponsePadding = padding{Min: 128, Max: 1024}
)

View file

@@ -0,0 +1,255 @@
package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/metacubex/mihomo/transport/hysteria2/core/errors"
"github.com/metacubex/quic-go/quicvarint"
)
const (
FrameTypeTCPRequest = 0x401
// Max length values are for preventing DoS attacks
MaxAddressLength = 2048
MaxMessageLength = 2048
MaxPaddingLength = 4096
MaxUDPSize = 4096
maxVarInt1 = 63
maxVarInt2 = 16383
maxVarInt4 = 1073741823
maxVarInt8 = 4611686018427387903
)
// TCPRequest format:
// 0x401 (QUIC varint)
// Address length (QUIC varint)
// Address (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadTCPRequest(r io.Reader) (string, error) {
bReader := quicvarint.NewReader(r)
addrLen, err := quicvarint.Read(bReader)
if err != nil {
return "", err
}
if addrLen == 0 || addrLen > MaxAddressLength {
return "", errors.ProtocolError{Message: "invalid address length"}
}
addrBuf := make([]byte, addrLen)
_, err = io.ReadFull(r, addrBuf)
if err != nil {
return "", err
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return "", err
}
if paddingLen > MaxPaddingLength {
return "", errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return "", err
}
}
return string(addrBuf), nil
}
func WriteTCPRequest(w io.Writer, addr string) error {
padding := tcpRequestPadding.String()
paddingLen := len(padding)
addrLen := len(addr)
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
int(quicvarint.Len(uint64(addrLen))) + addrLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
i := varintPut(buf, FrameTypeTCPRequest)
i += varintPut(buf[i:], uint64(addrLen))
i += copy(buf[i:], addr)
i += varintPut(buf[i:], uint64(paddingLen))
copy(buf[i:], padding)
_, err := w.Write(buf)
return err
}
// TCPResponse format:
// Status (byte, 0=ok, 1=error)
// Message length (QUIC varint)
// Message (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadTCPResponse(r io.Reader) (bool, string, error) {
var status [1]byte
if _, err := io.ReadFull(r, status[:]); err != nil {
return false, "", err
}
bReader := quicvarint.NewReader(r)
msgLen, err := quicvarint.Read(bReader)
if err != nil {
return false, "", err
}
if msgLen > MaxMessageLength {
return false, "", errors.ProtocolError{Message: "invalid message length"}
}
var msgBuf []byte
// No message is fine
if msgLen > 0 {
msgBuf = make([]byte, msgLen)
_, err = io.ReadFull(r, msgBuf)
if err != nil {
return false, "", err
}
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return false, "", err
}
if paddingLen > MaxPaddingLength {
return false, "", errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return false, "", err
}
}
return status[0] == 0, string(msgBuf), nil
}
func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
padding := tcpResponsePadding.String()
paddingLen := len(padding)
msgLen := len(msg)
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
if ok {
buf[0] = 0
} else {
buf[0] = 1
}
i := varintPut(buf[1:], uint64(msgLen))
i += copy(buf[1+i:], msg)
i += varintPut(buf[1+i:], uint64(paddingLen))
copy(buf[1+i:], padding)
_, err := w.Write(buf)
return err
}
// UDPMessage format:
// Session ID (uint32 BE)
// Packet ID (uint16 BE)
// Fragment ID (uint8)
// Fragment count (uint8)
// Address length (QUIC varint)
// Address (bytes)
// Data...
type UDPMessage struct {
SessionID uint32 // 4
PacketID uint16 // 2
FragID uint8 // 1
FragCount uint8 // 1
Addr string // varint + bytes
Data []byte
}
func (m *UDPMessage) HeaderSize() int {
lAddr := len(m.Addr)
return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
}
func (m *UDPMessage) Size() int {
return m.HeaderSize() + len(m.Data)
}
func (m *UDPMessage) Serialize(buf []byte) int {
// Make sure the buffer is big enough
if len(buf) < m.Size() {
return -1
}
binary.BigEndian.PutUint32(buf, m.SessionID)
binary.BigEndian.PutUint16(buf[4:], m.PacketID)
buf[6] = m.FragID
buf[7] = m.FragCount
i := varintPut(buf[8:], uint64(len(m.Addr)))
i += copy(buf[8+i:], m.Addr)
i += copy(buf[8+i:], m.Data)
return 8 + i
}
func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
m := &UDPMessage{}
buf := bytes.NewBuffer(msg)
if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
return nil, err
}
lAddr, err := quicvarint.Read(buf)
if err != nil {
return nil, err
}
if lAddr == 0 || lAddr > MaxMessageLength {
return nil, errors.ProtocolError{Message: "invalid address length"}
}
bs := buf.Bytes()
if len(bs) <= int(lAddr) {
// We use <= instead of < here as we expect at least one byte of data after the address
return nil, errors.ProtocolError{Message: "invalid message length"}
}
m.Addr = string(bs[:lAddr])
m.Data = bs[lAddr:]
return m, nil
}
// varintPut is like quicvarint.Append, but instead of appending to a slice,
// it writes to a fixed-size buffer. Returns the number of bytes written.
func varintPut(b []byte, i uint64) int {
if i <= maxVarInt1 {
b[0] = uint8(i)
return 1
}
if i <= maxVarInt2 {
b[0] = uint8(i>>8) | 0x40
b[1] = uint8(i)
return 2
}
if i <= maxVarInt4 {
b[0] = uint8(i>>24) | 0x80
b[1] = uint8(i >> 16)
b[2] = uint8(i >> 8)
b[3] = uint8(i)
return 4
}
if i <= maxVarInt8 {
b[0] = uint8(i>>56) | 0xc0
b[1] = uint8(i >> 48)
b[2] = uint8(i >> 40)
b[3] = uint8(i >> 32)
b[4] = uint8(i >> 24)
b[5] = uint8(i >> 16)
b[6] = uint8(i >> 8)
b[7] = uint8(i)
return 8
}
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
}
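// A serialize/parse sketch for the UDPMessage wire format documented above;
// field values are illustrative.
func exampleUDPMessageRoundTrip() (*UDPMessage, error) {
	m := UDPMessage{
		SessionID: 7,
		PacketID:  1,
		FragID:    0,
		FragCount: 1,
		Addr:      "1.1.1.1:53",
		Data:      []byte("hello"),
	}
	buf := make([]byte, m.Size())
	n := m.Serialize(buf) // -1 would mean the buffer was too small
	return ParseUDPMessage(buf[:n])
}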

View file

@@ -0,0 +1,24 @@
package utils
import (
"sync/atomic"
"time"
)
type AtomicTime struct {
v atomic.Value
}
func NewAtomicTime(t time.Time) *AtomicTime {
a := &AtomicTime{}
a.Set(t)
return a
}
func (t *AtomicTime) Set(new time.Time) {
t.v.Store(new)
}
func (t *AtomicTime) Get() time.Time {
return t.v.Load().(time.Time)
}

View file

@@ -0,0 +1,62 @@
package utils
import (
"context"
"time"
"github.com/metacubex/quic-go"
)
// QStream is a wrapper of quic.Stream that handles Close() in a way that
// makes more sense to us. By default, quic.Stream's Close() only closes
// the write side of the stream, not the read side. And if there is unread
// data, the stream is not really considered closed until either the data
// is drained or CancelRead() is called.
// References:
// - https://github.com/libp2p/go-libp2p/blob/master/p2p/transport/quic/stream.go
// - https://github.com/quic-go/quic-go/issues/3558
// - https://github.com/quic-go/quic-go/issues/1599
type QStream struct {
Stream quic.Stream
}
func (s *QStream) StreamID() quic.StreamID {
return s.Stream.StreamID()
}
func (s *QStream) Read(p []byte) (n int, err error) {
return s.Stream.Read(p)
}
func (s *QStream) CancelRead(code quic.StreamErrorCode) {
s.Stream.CancelRead(code)
}
func (s *QStream) SetReadDeadline(t time.Time) error {
return s.Stream.SetReadDeadline(t)
}
func (s *QStream) Write(p []byte) (n int, err error) {
return s.Stream.Write(p)
}
func (s *QStream) Close() error {
s.Stream.CancelRead(0)
return s.Stream.Close()
}
func (s *QStream) CancelWrite(code quic.StreamErrorCode) {
s.Stream.CancelWrite(code)
}
func (s *QStream) Context() context.Context {
return s.Stream.Context()
}
func (s *QStream) SetWriteDeadline(t time.Time) error {
return s.Stream.SetWriteDeadline(t)
}
func (s *QStream) SetDeadline(t time.Time) error {
return s.Stream.SetDeadline(t)
}

View file

@@ -0,0 +1,92 @@
package correctnet
import (
"net"
"net/http"
"strings"
)
func extractIPFamily(ip net.IP) (family string) {
if len(ip) == 0 {
		// an empty IP means a family-independent wildcard address, such as ":443"
return ""
}
if p4 := ip.To4(); len(p4) == net.IPv4len {
return "4"
}
return "6"
}
func tcpAddrNetwork(addr *net.TCPAddr) (network string) {
if addr == nil {
return "tcp"
}
return "tcp" + extractIPFamily(addr.IP)
}
func udpAddrNetwork(addr *net.UDPAddr) (network string) {
if addr == nil {
return "udp"
}
return "udp" + extractIPFamily(addr.IP)
}
func ipAddrNetwork(addr *net.IPAddr) (network string) {
if addr == nil {
return "ip"
}
return "ip" + extractIPFamily(addr.IP)
}
func Listen(network, address string) (net.Listener, error) {
if network == "tcp" {
tcpAddr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
return ListenTCP(network, tcpAddr)
}
return net.Listen(network, address)
}
func ListenTCP(network string, laddr *net.TCPAddr) (*net.TCPListener, error) {
if network == "tcp" {
return net.ListenTCP(tcpAddrNetwork(laddr), laddr)
}
return net.ListenTCP(network, laddr)
}
func ListenPacket(network, address string) (listener net.PacketConn, err error) {
if network == "udp" {
udpAddr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
return ListenUDP(network, udpAddr)
}
if strings.HasPrefix(network, "ip:") {
proto := network[3:]
ipAddr, err := net.ResolveIPAddr(proto, address)
if err != nil {
return nil, err
}
return net.ListenIP(ipAddrNetwork(ipAddr)+":"+proto, ipAddr)
}
return net.ListenPacket(network, address)
}
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
if network == "udp" {
return net.ListenUDP(udpAddrNetwork(laddr), laddr)
}
return net.ListenUDP(network, laddr)
}
func HTTPListenAndServe(address string, handler http.Handler) error {
listener, err := Listen("tcp", address)
if err != nil {
return err
}
defer listener.Close()
return http.Serve(listener, handler)
}

View file

@@ -0,0 +1,121 @@
package obfs
import (
"net"
"sync"
"syscall"
"time"
)
const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods.
// Both methods return the number of bytes written to out.
// If a packet is not valid, the methods should return 0.
type Obfuscator interface {
Obfuscate(in, out []byte) int
Deobfuscate(in, out []byte) int
}
var _ net.PacketConn = (*obfsPacketConn)(nil)
type obfsPacketConn struct {
Conn net.PacketConn
Obfs Obfuscator
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn
// as the underlying connection. We pass additional methods to quic-go to
// enable UDP-specific optimizations.
type obfsPacketConnUDP struct {
*obfsPacketConn
UDPConn *net.UDPConn
}
// WrapPacketConn enables obfuscation on a net.PacketConn.
// The obfuscation is transparent to the caller - the n bytes returned by
// ReadFrom and WriteTo are the number of original bytes, not after
// obfuscation/deobfuscation.
func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn {
opc := &obfsPacketConn{
Conn: conn,
Obfs: obfs,
readBuf: make([]byte, udpBufferSize),
writeBuf: make([]byte, udpBufferSize),
}
if udpConn, ok := conn.(*net.UDPConn); ok {
return &obfsPacketConnUDP{
obfsPacketConn: opc,
UDPConn: udpConn,
}
} else {
return opc
}
}
func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
for {
c.readMutex.Lock()
n, addr, err = c.Conn.ReadFrom(c.readBuf)
if n <= 0 {
c.readMutex.Unlock()
return
}
n = c.Obfs.Deobfuscate(c.readBuf[:n], p)
c.readMutex.Unlock()
if n > 0 || err != nil {
return
}
// Invalid packet, try again
}
}
func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMutex.Lock()
nn := c.Obfs.Obfuscate(p, c.writeBuf)
_, err = c.Conn.WriteTo(c.writeBuf[:nn], addr)
c.writeMutex.Unlock()
if err == nil {
n = len(p)
}
return
}
func (c *obfsPacketConn) Close() error {
return c.Conn.Close()
}
func (c *obfsPacketConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *obfsPacketConn) SetDeadline(t time.Time) error {
return c.Conn.SetDeadline(t)
}
func (c *obfsPacketConn) SetReadDeadline(t time.Time) error {
return c.Conn.SetReadDeadline(t)
}
func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t)
}
// UDP-specific methods below
func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error {
return c.UDPConn.SetReadBuffer(bytes)
}
func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error {
return c.UDPConn.SetWriteBuffer(bytes)
}
func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) {
return c.UDPConn.SyscallConn()
}
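// A usage sketch: wrap a plain UDP socket with the Salamander obfuscator (defined
// in this package) so every datagram is obfuscated transparently. The PSK handling
// is illustrative.
func exampleWrapPacketConn(psk string) (net.PacketConn, error) {
	udpConn, err := net.ListenUDP("udp", nil)
	if err != nil {
		return nil, err
	}
	ob, err := NewSalamanderObfuscator([]byte(psk))
	if err != nil {
		return nil, err
	}
	// The returned net.PacketConn can be handed to quic-go; callers keep working
	// with plaintext payload lengths.
	return WrapPacketConn(udpConn, ob), nil
}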

View file

@@ -0,0 +1,71 @@
package obfs
import (
"fmt"
"math/rand"
"sync"
"time"
"golang.org/x/crypto/blake2b"
)
const (
smPSKMinLen = 4
smSaltLen = 8
smKeyLen = blake2b.Size256
)
var _ Obfuscator = (*SalamanderObfuscator)(nil)
var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen)
// SalamanderObfuscator is an obfuscator that obfuscates each packet with
// the BLAKE2b-256 hash of a pre-shared key combined with a random salt.
// Packet format: [8-byte salt][payload]
type SalamanderObfuscator struct {
PSK []byte
RandSrc *rand.Rand
lk sync.Mutex
}
func NewSalamanderObfuscator(psk []byte) (*SalamanderObfuscator, error) {
if len(psk) < smPSKMinLen {
return nil, ErrPSKTooShort
}
return &SalamanderObfuscator{
PSK: psk,
RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
}, nil
}
func (o *SalamanderObfuscator) Obfuscate(in, out []byte) int {
outLen := len(in) + smSaltLen
if len(out) < outLen {
return 0
}
o.lk.Lock()
_, _ = o.RandSrc.Read(out[:smSaltLen])
o.lk.Unlock()
key := o.key(out[:smSaltLen])
for i, c := range in {
out[i+smSaltLen] = c ^ key[i%smKeyLen]
}
return outLen
}
func (o *SalamanderObfuscator) Deobfuscate(in, out []byte) int {
outLen := len(in) - smSaltLen
if outLen <= 0 || len(out) < outLen {
return 0
}
key := o.key(in[:smSaltLen])
for i, c := range in[smSaltLen:] {
out[i] = c ^ key[i%smKeyLen]
}
return outLen
}
func (o *SalamanderObfuscator) key(salt []byte) [smKeyLen]byte {
return blake2b.Sum256(append(o.PSK, salt...))
}

View file

@@ -0,0 +1,45 @@
package obfs
import (
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) {
o, _ := NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
_, _ = rand.Read(in)
out := make([]byte, 2048)
b.ResetTimer()
for i := 0; i < b.N; i++ {
o.Obfuscate(in, out)
}
}
func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) {
o, _ := NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
_, _ = rand.Read(in)
out := make([]byte, 2048)
b.ResetTimer()
for i := 0; i < b.N; i++ {
o.Deobfuscate(in, out)
}
}
func TestSalamanderObfuscator(t *testing.T) {
o, _ := NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
oOut := make([]byte, 2048)
dOut := make([]byte, 2048)
for i := 0; i < 1000; i++ {
_, _ = rand.Read(in)
n := o.Obfuscate(in, oOut)
assert.Equal(t, len(in)+smSaltLen, n)
n = o.Deobfuscate(oOut[:n], dOut)
assert.Equal(t, len(in), n)
assert.Equal(t, in, dOut[:n])
}
}

View file

@@ -0,0 +1,92 @@
package udphop
import (
"fmt"
"net"
"strconv"
"strings"
)
type InvalidPortError struct {
PortStr string
}
func (e InvalidPortError) Error() string {
return fmt.Sprintf("%s is not a valid port number or range", e.PortStr)
}
// UDPHopAddr contains an IP address and a list of ports.
type UDPHopAddr struct {
IP net.IP
Ports []uint16
PortStr string
}
func (a *UDPHopAddr) Network() string {
return "udphop"
}
func (a *UDPHopAddr) String() string {
return net.JoinHostPort(a.IP.String(), a.PortStr)
}
// addrs returns a list of net.Addr's, one for each port.
func (a *UDPHopAddr) addrs() ([]net.Addr, error) {
var addrs []net.Addr
for _, port := range a.Ports {
addr := &net.UDPAddr{
IP: a.IP,
Port: int(port),
}
addrs = append(addrs, addr)
}
return addrs, nil
}
func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ip, err := net.ResolveIPAddr("ip", host)
if err != nil {
return nil, err
}
result := &UDPHopAddr{
IP: ip.IP,
PortStr: portStr,
}
portStrs := strings.Split(portStr, ",")
for _, portStr := range portStrs {
if strings.Contains(portStr, "-") {
// Port range
portRange := strings.Split(portStr, "-")
if len(portRange) != 2 {
return nil, InvalidPortError{portStr}
}
start, err := strconv.ParseUint(portRange[0], 10, 16)
if err != nil {
return nil, InvalidPortError{portStr}
}
end, err := strconv.ParseUint(portRange[1], 10, 16)
if err != nil {
return nil, InvalidPortError{portStr}
}
if start > end {
start, end = end, start
}
for i := start; i <= end; i++ {
result.Ports = append(result.Ports, uint16(i))
}
} else {
// Single port
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, InvalidPortError{portStr}
}
result.Ports = append(result.Ports, uint16(port))
}
}
return result, nil
}
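// A parsing sketch: the port part accepts comma-separated single ports and
// ranges. The host and ports here are illustrative.
func exampleResolveHopAddr() (*UDPHopAddr, error) {
	// Expands to ports 20000 through 20010 plus 30000 on the resolved IP.
	return ResolveUDPHopAddr("example.com:20000-20010,30000")
}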

View file

@@ -0,0 +1,288 @@
package udphop
import (
"errors"
"math/rand"
"net"
"sync"
"syscall"
"time"
"github.com/metacubex/mihomo/log"
)
const (
packetQueueSize = 1024
udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
defaultHopInterval = 30 * time.Second
)
type udpHopPacketConn struct {
Addr net.Addr
Addrs []net.Addr
HopInterval time.Duration
connMutex sync.RWMutex
prevConn net.PacketConn
currentConn net.PacketConn
addrIndex int
readBufferSize int
writeBufferSize int
recvQueue chan *udpPacket
closeChan chan struct{}
closed bool
bufPool sync.Pool
}
type udpPacket struct {
Buf []byte
N int
Addr net.Addr
Err error
}
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) {
if hopInterval == 0 {
hopInterval = defaultHopInterval
} else if hopInterval < 5*time.Second {
return nil, errors.New("hop interval must be at least 5 seconds")
}
addrs, err := addr.addrs()
if err != nil {
return nil, err
}
curConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
hConn := &udpHopPacketConn{
Addr: addr,
Addrs: addrs,
HopInterval: hopInterval,
prevConn: nil,
currentConn: curConn,
addrIndex: rand.Intn(len(addrs)),
recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
return make([]byte, udpBufferSize)
},
},
}
go hConn.recvLoop(curConn)
go hConn.hopLoop()
return hConn, nil
}
func (u *udpHopPacketConn) recvLoop(conn net.PacketConn) {
for {
buf := u.bufPool.Get().([]byte)
n, addr, err := conn.ReadFrom(buf)
if err != nil {
u.bufPool.Put(buf)
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
// Only pass through timeout errors here, not permanent errors
// like connection closed. Connection close is normal as we close
// the old connection to exit this loop every time we hop.
u.recvQueue <- &udpPacket{nil, 0, nil, netErr}
}
return
}
select {
case u.recvQueue <- &udpPacket{buf, n, addr, nil}:
// Packet successfully queued
default:
// Queue is full, drop the packet
u.bufPool.Put(buf)
}
}
}
func (u *udpHopPacketConn) hopLoop() {
ticker := time.NewTicker(u.HopInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
u.hop()
case <-u.closeChan:
return
}
}
}
func (u *udpHopPacketConn) hop() {
u.connMutex.Lock()
defer u.connMutex.Unlock()
if u.closed {
return
}
newConn, err := net.ListenUDP("udp", nil)
if err != nil {
// Could be temporary, just skip this hop
return
}
// We need to keep receiving packets from the previous connection,
// because otherwise there will be packet loss due to the time gap
	// between when we hop to a new port and when the server acknowledges the change.
// So we do the following:
// Close prevConn,
// move currentConn to prevConn,
// set newConn as currentConn,
// start recvLoop on newConn.
if u.prevConn != nil {
_ = u.prevConn.Close() // recvLoop for this conn will exit
}
u.prevConn = u.currentConn
u.currentConn = newConn
// Set buffer sizes if previously set
if u.readBufferSize > 0 {
_ = trySetReadBuffer(u.currentConn, u.readBufferSize)
}
if u.writeBufferSize > 0 {
_ = trySetWriteBuffer(u.currentConn, u.writeBufferSize)
}
go u.recvLoop(newConn)
// Update addrIndex to a new random value
u.addrIndex = rand.Intn(len(u.Addrs))
log.Infoln("hopped to %s", u.Addrs[u.addrIndex])
}
func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
for {
select {
case p := <-u.recvQueue:
if p.Err != nil {
return 0, nil, p.Err
}
// Currently we do not check whether the packet is from
// the server or not due to performance reasons.
n := copy(b, p.Buf[:p.N])
u.bufPool.Put(p.Buf)
return n, u.Addr, nil
case <-u.closeChan:
return 0, nil, net.ErrClosed
}
}
}
func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
u.connMutex.RLock()
defer u.connMutex.RUnlock()
if u.closed {
return 0, net.ErrClosed
}
// Skip the check for now, always write to the server,
// for the same reason as in ReadFrom.
return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex])
}
func (u *udpHopPacketConn) Close() error {
u.connMutex.Lock()
defer u.connMutex.Unlock()
if u.closed {
return nil
}
// Close prevConn and currentConn
// Close closeChan to unblock ReadFrom & hopLoop
// Set closed flag to true to prevent double close
if u.prevConn != nil {
_ = u.prevConn.Close()
}
err := u.currentConn.Close()
close(u.closeChan)
u.closed = true
u.Addrs = nil // For GC
return err
}
func (u *udpHopPacketConn) LocalAddr() net.Addr {
u.connMutex.RLock()
defer u.connMutex.RUnlock()
return u.currentConn.LocalAddr()
}
func (u *udpHopPacketConn) SetDeadline(t time.Time) error {
u.connMutex.RLock()
defer u.connMutex.RUnlock()
if u.prevConn != nil {
_ = u.prevConn.SetDeadline(t)
}
return u.currentConn.SetDeadline(t)
}
func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error {
u.connMutex.RLock()
defer u.connMutex.RUnlock()
if u.prevConn != nil {
_ = u.prevConn.SetReadDeadline(t)
}
return u.currentConn.SetReadDeadline(t)
}
func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error {
u.connMutex.RLock()
defer u.connMutex.RUnlock()
if u.prevConn != nil {
_ = u.prevConn.SetWriteDeadline(t)
}
return u.currentConn.SetWriteDeadline(t)
}
// UDP-specific methods below
func (u *udpHopPacketConn) SetReadBuffer(bytes int) error {
u.connMutex.Lock()
defer u.connMutex.Unlock()
u.readBufferSize = bytes
if u.prevConn != nil {
_ = trySetReadBuffer(u.prevConn, bytes)
}
return trySetReadBuffer(u.currentConn, bytes)
}
func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error {
u.connMutex.Lock()
defer u.connMutex.Unlock()
u.writeBufferSize = bytes
if u.prevConn != nil {
_ = trySetWriteBuffer(u.prevConn, bytes)
}
return trySetWriteBuffer(u.currentConn, bytes)
}
func (u *udpHopPacketConn) SyscallConn() (syscall.RawConn, error) {
u.connMutex.RLock()
defer u.connMutex.RUnlock()
sc, ok := u.currentConn.(syscall.Conn)
if !ok {
return nil, errors.New("not supported")
}
return sc.SyscallConn()
}
func trySetReadBuffer(pc net.PacketConn, bytes int) error {
sc, ok := pc.(interface {
SetReadBuffer(bytes int) error
})
if ok {
return sc.SetReadBuffer(bytes)
}
return nil
}
func trySetWriteBuffer(pc net.PacketConn, bytes int) error {
sc, ok := pc.(interface {
SetWriteBuffer(bytes int) error
})
if ok {
return sc.SetWriteBuffer(bytes)
}
return nil
}
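// An end-to-end sketch: build a hopping net.PacketConn that rebinds to a fresh
// local port and picks a random server port from the range on every interval.
// The host, port range and interval are illustrative.
func exampleHopConn() (net.PacketConn, error) {
	addr, err := ResolveUDPHopAddr("example.com:20000-25000")
	if err != nil {
		return nil, err
	}
	// 30s matches defaultHopInterval above; anything below five seconds is rejected.
	return NewUDPHopPacketConn(addr, 30*time.Second)
}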