mirror of
https://github.com/MetaCubeX/Clash.Meta.git
synced 2025-04-20 01:00:56 +00:00
support port hopping for hysteria 2
This commit is contained in:
parent
460cc240b0
commit
2a0139e236
36 changed files with 5282 additions and 101 deletions
|
@ -2,137 +2,123 @@ package outbound
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
CN "github.com/metacubex/mihomo/common/net"
|
||||
"github.com/metacubex/mihomo/component/ca"
|
||||
"github.com/metacubex/mihomo/component/dialer"
|
||||
"github.com/metacubex/mihomo/component/proxydialer"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
tuicCommon "github.com/metacubex/mihomo/transport/tuic/common"
|
||||
|
||||
"github.com/metacubex/sing-quic/hysteria2"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/app/cmd"
|
||||
hy2client "github.com/metacubex/mihomo/transport/hysteria2/core/client"
|
||||
)
|
||||
|
||||
func init() {
|
||||
hysteria2.SetCongestionController = tuicCommon.SetCongestionController
|
||||
}
|
||||
const minHopInterval = 5
|
||||
const defaultHopInterval = 30
|
||||
|
||||
type Hysteria2 struct {
|
||||
*Base
|
||||
|
||||
option *Hysteria2Option
|
||||
client *hysteria2.Client
|
||||
dialer proxydialer.SingDialer
|
||||
client hy2client.Client
|
||||
}
|
||||
|
||||
type Hysteria2Option struct {
|
||||
BasicOption
|
||||
Name string `proxy:"name"`
|
||||
Server string `proxy:"server"`
|
||||
Port int `proxy:"port"`
|
||||
Up string `proxy:"up,omitempty"`
|
||||
Down string `proxy:"down,omitempty"`
|
||||
Password string `proxy:"password,omitempty"`
|
||||
Obfs string `proxy:"obfs,omitempty"`
|
||||
ObfsPassword string `proxy:"obfs-password,omitempty"`
|
||||
SNI string `proxy:"sni,omitempty"`
|
||||
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
|
||||
Fingerprint string `proxy:"fingerprint,omitempty"`
|
||||
ALPN []string `proxy:"alpn,omitempty"`
|
||||
CustomCA string `proxy:"ca,omitempty"`
|
||||
CustomCAString string `proxy:"ca-str,omitempty"`
|
||||
CWND int `proxy:"cwnd,omitempty"`
|
||||
Name string `proxy:"name"`
|
||||
Server string `proxy:"server"`
|
||||
Port uint16 `proxy:"port,omitempty"`
|
||||
Ports string `proxy:"ports,omitempty"`
|
||||
HopInterval time.Duration `proxy:"hop-interval,omitempty"`
|
||||
Up string `proxy:"up"`
|
||||
Down string `proxy:"down"`
|
||||
Password string `proxy:"password,omitempty"`
|
||||
Obfs string `proxy:"obfs,omitempty"`
|
||||
ObfsPassword string `proxy:"obfs-password,omitempty"`
|
||||
SNI string `proxy:"sni,omitempty"`
|
||||
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
|
||||
Fingerprint string `proxy:"fingerprint,omitempty"`
|
||||
ALPN []string `proxy:"alpn,omitempty"`
|
||||
CustomCA string `proxy:"ca,omitempty"`
|
||||
CustomCAString string `proxy:"ca-str,omitempty"`
|
||||
CWND int `proxy:"cwnd,omitempty"`
|
||||
FastOpen bool `proxy:"fast-open,omitempty"`
|
||||
Lazy bool `proxy:"lazy,omitempty"`
|
||||
}
|
||||
|
||||
func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
|
||||
options := h.Base.DialOptions(opts...)
|
||||
h.dialer.SetDialer(dialer.NewDialer(options...))
|
||||
c, err := h.client.DialConn(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
|
||||
tcpConn, err := h.client.TCP(net.JoinHostPort(metadata.String(), strconv.Itoa(int(metadata.DstPort))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(CN.NewRefConn(c, h), h), nil
|
||||
|
||||
return NewConn(tcpConn, h), nil
|
||||
}
|
||||
|
||||
func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
|
||||
options := h.Base.DialOptions(opts...)
|
||||
h.dialer.SetDialer(dialer.NewDialer(options...))
|
||||
pc, err := h.client.ListenPacket(ctx)
|
||||
udpConn, err := h.client.UDP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pc == nil {
|
||||
return nil, errors.New("packetConn is nil")
|
||||
}
|
||||
return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), h), h), nil
|
||||
}
|
||||
|
||||
func closeHysteria2(h *Hysteria2) {
|
||||
if h.client != nil {
|
||||
_ = h.client.CloseWithError(errors.New("proxy removed"))
|
||||
}
|
||||
return newPacketConn(udpConn, h), nil
|
||||
}
|
||||
|
||||
func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
|
||||
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
|
||||
var salamanderPassword string
|
||||
if len(option.Obfs) > 0 {
|
||||
if option.ObfsPassword == "" {
|
||||
return nil, errors.New("missing obfs password")
|
||||
}
|
||||
switch option.Obfs {
|
||||
case hysteria2.ObfsTypeSalamander:
|
||||
salamanderPassword = option.ObfsPassword
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown obfs type: %s", option.Obfs)
|
||||
}
|
||||
var server string
|
||||
if option.Ports != "" {
|
||||
server = net.JoinHostPort(option.Server, option.Ports)
|
||||
} else {
|
||||
server = net.JoinHostPort(option.Server, strconv.Itoa(int(option.Port)))
|
||||
}
|
||||
|
||||
serverName := option.Server
|
||||
if option.SNI != "" {
|
||||
serverName = option.SNI
|
||||
if option.HopInterval == 0 {
|
||||
option.HopInterval = defaultHopInterval
|
||||
} else if option.HopInterval < minHopInterval {
|
||||
option.HopInterval = minHopInterval
|
||||
}
|
||||
option.HopInterval *= time.Second
|
||||
|
||||
config := cmd.ClientConfig{
|
||||
Server: server,
|
||||
Auth: option.Password,
|
||||
Transport: cmd.ClientConfigTransport{
|
||||
UDP: cmd.ClientConfigTransportUDP{
|
||||
HopInterval: option.HopInterval,
|
||||
},
|
||||
},
|
||||
TLS: cmd.ClientConfigTLS{
|
||||
SNI: option.SNI,
|
||||
Insecure: option.SkipCertVerify,
|
||||
PinSHA256: option.Fingerprint,
|
||||
CA: option.CustomCA,
|
||||
CAString: option.CustomCAString,
|
||||
},
|
||||
FastOpen: option.FastOpen,
|
||||
Lazy: option.Lazy,
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: option.SkipCertVerify,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
if option.ObfsPassword != "" {
|
||||
config.Obfs.Type = "salamander"
|
||||
config.Obfs.Salamander.Password = option.ObfsPassword
|
||||
} else if option.Obfs != "" {
|
||||
config.Obfs.Type = "salamander"
|
||||
config.Obfs.Salamander.Password = option.Obfs
|
||||
}
|
||||
|
||||
var err error
|
||||
tlsConfig, err = ca.GetTLSConfig(tlsConfig, option.Fingerprint, option.CustomCA, option.CustomCAString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
last := option.Up[len(option.Up)-1]
|
||||
if '0' <= last && last <= '9' {
|
||||
option.Up += "m"
|
||||
}
|
||||
|
||||
if len(option.ALPN) > 0 {
|
||||
tlsConfig.NextProtos = option.ALPN
|
||||
config.Bandwidth.Up = option.Up
|
||||
last = option.Down[len(option.Down)-1]
|
||||
if '0' <= last && last <= '9' {
|
||||
option.Down += "m"
|
||||
}
|
||||
config.Bandwidth.Down = option.Down
|
||||
|
||||
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer())
|
||||
|
||||
clientOptions := hysteria2.ClientOptions{
|
||||
Context: context.TODO(),
|
||||
Dialer: singDialer,
|
||||
ServerAddress: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)),
|
||||
SendBPS: StringToBps(option.Up),
|
||||
ReceiveBPS: StringToBps(option.Down),
|
||||
SalamanderPassword: salamanderPassword,
|
||||
Password: option.Password,
|
||||
TLSConfig: tlsConfig,
|
||||
UDPDisabled: false,
|
||||
CWND: option.CWND,
|
||||
}
|
||||
|
||||
client, err := hysteria2.NewClient(clientOptions)
|
||||
client, err := hy2client.NewReconnectableClient(
|
||||
config.Config,
|
||||
func(c hy2client.Client, info *hy2client.HandshakeInfo, count int) {},
|
||||
option.Lazy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -140,7 +126,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
|
|||
outbound := &Hysteria2{
|
||||
Base: &Base{
|
||||
name: option.Name,
|
||||
addr: addr,
|
||||
addr: server,
|
||||
tp: C.Hysteria2,
|
||||
udp: true,
|
||||
iface: option.Interface,
|
||||
|
@ -149,9 +135,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
|
|||
},
|
||||
option: &option,
|
||||
client: client,
|
||||
dialer: singDialer,
|
||||
}
|
||||
runtime.SetFinalizer(outbound, closeHysteria2)
|
||||
|
||||
return outbound, nil
|
||||
}
|
||||
|
|
8
go.mod
8
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/metacubex/mihomo
|
||||
|
||||
go 1.20
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/3andne/restls-client-go v0.1.6
|
||||
|
@ -65,11 +65,12 @@ require (
|
|||
github.com/andybalholm/brotli v1.0.6 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/cloudflare/circl v1.3.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/ericlagergren/aegis v0.0.0-20230312195928-b4ce538b56f9 // indirect
|
||||
github.com/ericlagergren/polyval v0.0.0-20220411101811-e25bc10ba391 // indirect
|
||||
github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 // indirect
|
||||
github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010 // indirect
|
||||
github.com/frankban/quicktest v1.14.6 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gaukas/godicttls v0.0.4 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
|
@ -89,7 +90,7 @@ require (
|
|||
github.com/oasisprotocol/deoxysii v0.0.0-20220228165953-2091330c22b7 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.14 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
|
||||
|
@ -110,6 +111,7 @@ require (
|
|||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.16.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||
)
|
||||
|
||||
replace github.com/sagernet/sing => github.com/metacubex/sing v0.0.0-20240111014253-f1818b6a82b2
|
||||
|
|
17
go.sum
17
go.sum
|
@ -25,9 +25,11 @@ github.com/cloudflare/circl v1.3.6 h1:/xbKIqSHbZXHwkhbrhrt2YOHIwYJlXH94E3tI/gDlU
|
|||
github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA=
|
||||
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
|
||||
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/ericlagergren/aegis v0.0.0-20230312195928-b4ce538b56f9 h1:/5RkVc9Rc81XmMyVqawCiDyrBHZbLAZgTTCqou4mwj8=
|
||||
|
@ -39,7 +41,8 @@ github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 h1:tlDMEdcPRQKBE
|
|||
github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1/go.mod h1:4RfsapbGx2j/vU5xC/5/9qB3kn9Awp1YDiEnN43QrJ4=
|
||||
github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010 h1:fuGucgPk5dN6wzfnxl3D0D3rVLw4v2SbBT9jb4VnxzA=
|
||||
github.com/ericlagergren/subtle v0.0.0-20220507045147-890d697da010/go.mod h1:JtBcj7sBuTTRupn7c2bFspMDIObMJsVK8TeUvpShPok=
|
||||
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk=
|
||||
|
@ -91,7 +94,9 @@ github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6K
|
|||
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc=
|
||||
|
@ -140,9 +145,11 @@ github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq5
|
|||
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
||||
github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE=
|
||||
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||
|
@ -153,6 +160,7 @@ github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1
|
|||
github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs=
|
||||
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkkD2QgdTuzQG263YZ+2emfpeyGqW0=
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM=
|
||||
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
|
||||
|
@ -270,7 +278,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
|||
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
|
||||
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
1
transport/hysteria2/README.md
Normal file
1
transport/hysteria2/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
Copied from [hysteria](https://github.com/apernet/hysteria) v2.2.3 with a little changes.
|
374
transport/hysteria2/app/cmd/client.go
Normal file
374
transport/hysteria2/app/cmd/client.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/app/utils"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/client"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/extras/obfs"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/extras/transport/udphop"
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Server string `mapstructure:"server"`
|
||||
Auth string `mapstructure:"auth"`
|
||||
Transport ClientConfigTransport `mapstructure:"transport"`
|
||||
Obfs ClientConfigObfs `mapstructure:"obfs"`
|
||||
TLS ClientConfigTLS `mapstructure:"tls"`
|
||||
QUIC clientConfigQUIC `mapstructure:"quic"`
|
||||
Bandwidth ClientConfigBandwidth `mapstructure:"bandwidth"`
|
||||
FastOpen bool `mapstructure:"fastOpen"`
|
||||
Lazy bool `mapstructure:"lazy"`
|
||||
}
|
||||
|
||||
type ClientConfigTransportUDP struct {
|
||||
HopInterval time.Duration `mapstructure:"hopInterval"`
|
||||
}
|
||||
|
||||
type ClientConfigTransport struct {
|
||||
Type string `mapstructure:"type"`
|
||||
UDP ClientConfigTransportUDP `mapstructure:"udp"`
|
||||
}
|
||||
|
||||
type ClientConfigObfsSalamander struct {
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
|
||||
type ClientConfigObfs struct {
|
||||
Type string `mapstructure:"type"`
|
||||
Salamander ClientConfigObfsSalamander `mapstructure:"salamander"`
|
||||
}
|
||||
|
||||
type ClientConfigTLS struct {
|
||||
SNI string `mapstructure:"sni"`
|
||||
Insecure bool `mapstructure:"insecure"`
|
||||
PinSHA256 string `mapstructure:"pinSHA256"`
|
||||
CA string `mapstructure:"ca"`
|
||||
CAString string `mapstructure:"ca-str"`
|
||||
}
|
||||
|
||||
type clientConfigQUIC struct {
|
||||
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
|
||||
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
|
||||
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
|
||||
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
|
||||
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
|
||||
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
|
||||
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
|
||||
}
|
||||
|
||||
type ClientConfigBandwidth struct {
|
||||
Up string `mapstructure:"up"`
|
||||
Down string `mapstructure:"down"`
|
||||
}
|
||||
|
||||
func (c *ClientConfig) fillServerAddr(hyConfig *client.Config) error {
|
||||
if c.Server == "" {
|
||||
return configError{Field: "server", Err: errors.New("server address is empty")}
|
||||
}
|
||||
var addr net.Addr
|
||||
var err error
|
||||
host, port, hostPort := parseServerAddrString(c.Server)
|
||||
if !isPortHoppingPort(port) {
|
||||
addr, err = net.ResolveUDPAddr("udp", hostPort)
|
||||
} else {
|
||||
addr, err = udphop.ResolveUDPHopAddr(hostPort)
|
||||
}
|
||||
if err != nil {
|
||||
return configError{Field: "server", Err: err}
|
||||
}
|
||||
hyConfig.ServerAddr = addr
|
||||
// Special handling for SNI
|
||||
if c.TLS.SNI == "" {
|
||||
// Use server hostname as SNI
|
||||
hyConfig.TLSConfig.ServerName = host
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fillConnFactory must be called after fillServerAddr, as we have different logic
|
||||
// for ConnFactory depending on whether we have a port hopping address.
|
||||
func (c *ClientConfig) fillConnFactory(hyConfig *client.Config) error {
|
||||
// Inner PacketConn
|
||||
var newFunc func(addr net.Addr) (net.PacketConn, error)
|
||||
switch strings.ToLower(c.Transport.Type) {
|
||||
case "", "udp":
|
||||
if hyConfig.ServerAddr.Network() == "udphop" {
|
||||
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
|
||||
newFunc = func(addr net.Addr) (net.PacketConn, error) {
|
||||
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval)
|
||||
}
|
||||
} else {
|
||||
newFunc = func(addr net.Addr) (net.PacketConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return configError{Field: "transport.type", Err: errors.New("unsupported transport type")}
|
||||
}
|
||||
// Obfuscation
|
||||
var ob obfs.Obfuscator
|
||||
var err error
|
||||
switch strings.ToLower(c.Obfs.Type) {
|
||||
case "", "plain":
|
||||
// Keep it nil
|
||||
case "salamander":
|
||||
ob, err = obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password))
|
||||
if err != nil {
|
||||
return configError{Field: "obfs.salamander.password", Err: err}
|
||||
}
|
||||
default:
|
||||
return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")}
|
||||
}
|
||||
hyConfig.ConnFactory = &adaptiveConnFactory{
|
||||
NewFunc: newFunc,
|
||||
Obfuscator: ob,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientConfig) fillAuth(hyConfig *client.Config) error {
|
||||
hyConfig.Auth = c.Auth
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientConfig) fillTLSConfig(hyConfig *client.Config) error {
|
||||
if c.TLS.SNI != "" {
|
||||
hyConfig.TLSConfig.ServerName = c.TLS.SNI
|
||||
}
|
||||
hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure
|
||||
if c.TLS.PinSHA256 != "" {
|
||||
nHash := normalizeCertHash(c.TLS.PinSHA256)
|
||||
hyConfig.TLSConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
for _, cert := range rawCerts {
|
||||
hash := sha256.Sum256(cert)
|
||||
hashHex := hex.EncodeToString(hash[:])
|
||||
if hashHex == nHash {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// No match
|
||||
return errors.New("no certificate matches the pinned hash")
|
||||
}
|
||||
}
|
||||
if c.TLS.CAString != "" || c.TLS.CA != "" {
|
||||
var ca []byte
|
||||
if c.TLS.CAString != "" {
|
||||
ca = []byte(c.TLS.CAString)
|
||||
} else {
|
||||
var err error
|
||||
ca, err = os.ReadFile(c.TLS.CA)
|
||||
if err != nil {
|
||||
return configError{Field: "tls.ca", Err: err}
|
||||
}
|
||||
}
|
||||
cPool := x509.NewCertPool()
|
||||
if !cPool.AppendCertsFromPEM(ca) {
|
||||
return configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
|
||||
}
|
||||
hyConfig.TLSConfig.RootCAs = cPool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientConfig) fillQUICConfig(hyConfig *client.Config) error {
|
||||
hyConfig.QUICConfig = client.QUICConfig{
|
||||
InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
|
||||
MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
|
||||
KeepAlivePeriod: c.QUIC.KeepAlivePeriod,
|
||||
DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientConfig) fillBandwidthConfig(hyConfig *client.Config) error {
|
||||
// New core now allows users to omit bandwidth values and use built-in congestion control
|
||||
var err error
|
||||
if c.Bandwidth.Up != "" {
|
||||
hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up)
|
||||
if err != nil {
|
||||
return configError{Field: "bandwidth.up", Err: err}
|
||||
}
|
||||
}
|
||||
if c.Bandwidth.Down != "" {
|
||||
hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down)
|
||||
if err != nil {
|
||||
return configError{Field: "bandwidth.down", Err: err}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientConfig) fillFastOpen(hyConfig *client.Config) error {
|
||||
hyConfig.FastOpen = c.FastOpen
|
||||
return nil
|
||||
}
|
||||
|
||||
// URI generates a URI for sharing the config with others.
|
||||
// Note that only the bare minimum of information required to
|
||||
// connect to the server is included in the URI, specifically:
|
||||
// - server address
|
||||
// - authentication
|
||||
// - obfuscation type
|
||||
// - obfuscation password
|
||||
// - TLS SNI
|
||||
// - TLS insecure
|
||||
// - TLS pinned SHA256 hash (normalized)
|
||||
func (c *ClientConfig) URI() string {
|
||||
q := url.Values{}
|
||||
switch strings.ToLower(c.Obfs.Type) {
|
||||
case "salamander":
|
||||
q.Set("obfs", "salamander")
|
||||
q.Set("obfs-password", c.Obfs.Salamander.Password)
|
||||
}
|
||||
if c.TLS.SNI != "" {
|
||||
q.Set("sni", c.TLS.SNI)
|
||||
}
|
||||
if c.TLS.Insecure {
|
||||
q.Set("insecure", "1")
|
||||
}
|
||||
if c.TLS.PinSHA256 != "" {
|
||||
q.Set("pinSHA256", normalizeCertHash(c.TLS.PinSHA256))
|
||||
}
|
||||
var user *url.Userinfo
|
||||
if c.Auth != "" {
|
||||
// We need to handle the special case of user:pass pairs
|
||||
rs := strings.SplitN(c.Auth, ":", 2)
|
||||
if len(rs) == 2 {
|
||||
user = url.UserPassword(rs[0], rs[1])
|
||||
} else {
|
||||
user = url.User(c.Auth)
|
||||
}
|
||||
}
|
||||
u := url.URL{
|
||||
Scheme: "hysteria2",
|
||||
User: user,
|
||||
Host: c.Server,
|
||||
Path: "/",
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// parseURI tries to parse the server address field as a URI,
|
||||
// and fills the config with the information contained in the URI.
|
||||
// Returns whether the server address field is a valid URI.
|
||||
// This allows a user to use put a URI as the server address and
|
||||
// omit the fields that are already contained in the URI.
|
||||
func (c *ClientConfig) parseURI() bool {
|
||||
u, err := url.Parse(c.Server)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if u.Scheme != "hysteria2" && u.Scheme != "hy2" {
|
||||
return false
|
||||
}
|
||||
if u.User != nil {
|
||||
c.Auth = u.User.String()
|
||||
}
|
||||
c.Server = u.Host
|
||||
q := u.Query()
|
||||
if obfsType := q.Get("obfs"); obfsType != "" {
|
||||
c.Obfs.Type = obfsType
|
||||
switch strings.ToLower(obfsType) {
|
||||
case "salamander":
|
||||
c.Obfs.Salamander.Password = q.Get("obfs-password")
|
||||
}
|
||||
}
|
||||
if sni := q.Get("sni"); sni != "" {
|
||||
c.TLS.SNI = sni
|
||||
}
|
||||
if insecure, err := strconv.ParseBool(q.Get("insecure")); err == nil {
|
||||
c.TLS.Insecure = insecure
|
||||
}
|
||||
if pinSHA256 := q.Get("pinSHA256"); pinSHA256 != "" {
|
||||
c.TLS.PinSHA256 = pinSHA256
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Config validates the fields and returns a ready-to-use Hysteria client config
|
||||
func (c *ClientConfig) Config() (*client.Config, error) {
|
||||
c.parseURI()
|
||||
hyConfig := &client.Config{}
|
||||
fillers := []func(*client.Config) error{
|
||||
c.fillServerAddr,
|
||||
c.fillConnFactory,
|
||||
c.fillAuth,
|
||||
c.fillTLSConfig,
|
||||
c.fillQUICConfig,
|
||||
c.fillBandwidthConfig,
|
||||
c.fillFastOpen,
|
||||
}
|
||||
for _, f := range fillers {
|
||||
if err := f(hyConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return hyConfig, nil
|
||||
}
|
||||
|
||||
type clientModeRunner struct {
|
||||
ModeMap map[string]func() error
|
||||
}
|
||||
|
||||
func (r *clientModeRunner) Add(name string, f func() error) {
|
||||
if r.ModeMap == nil {
|
||||
r.ModeMap = make(map[string]func() error)
|
||||
}
|
||||
r.ModeMap[name] = f
|
||||
}
|
||||
|
||||
// parseServerAddrString parses server address string.
|
||||
// Server address can be in either "host:port" or "host" format (in which case we assume port 443).
|
||||
func parseServerAddrString(addrStr string) (host, port, hostPort string) {
|
||||
h, p, err := net.SplitHostPort(addrStr)
|
||||
if err != nil {
|
||||
return addrStr, "443", net.JoinHostPort(addrStr, "443")
|
||||
}
|
||||
return h, p, addrStr
|
||||
}
|
||||
|
||||
// isPortHoppingPort returns whether the port string is a port hopping port.
|
||||
// We consider a port string to be a port hopping port if it contains "-" or ",".
|
||||
func isPortHoppingPort(port string) bool {
|
||||
return strings.Contains(port, "-") || strings.Contains(port, ",")
|
||||
}
|
||||
|
||||
// normalizeCertHash normalizes a certificate hash string.
|
||||
// It converts all characters to lowercase and removes possible separators such as ":" and "-".
|
||||
func normalizeCertHash(hash string) string {
|
||||
r := strings.ToLower(hash)
|
||||
r = strings.ReplaceAll(r, ":", "")
|
||||
r = strings.ReplaceAll(r, "-", "")
|
||||
return r
|
||||
}
|
||||
|
||||
type adaptiveConnFactory struct {
|
||||
NewFunc func(addr net.Addr) (net.PacketConn, error)
|
||||
Obfuscator obfs.Obfuscator // nil if no obfuscation
|
||||
}
|
||||
|
||||
func (f *adaptiveConnFactory) New(addr net.Addr) (net.PacketConn, error) {
|
||||
if f.Obfuscator == nil {
|
||||
return f.NewFunc(addr)
|
||||
} else {
|
||||
conn, err := f.NewFunc(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return obfs.WrapPacketConn(conn, f.Obfuscator), nil
|
||||
}
|
||||
}
|
18
transport/hysteria2/app/cmd/errors.go
Normal file
18
transport/hysteria2/app/cmd/errors.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type configError struct {
|
||||
Field string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e configError) Error() string {
|
||||
return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err)
|
||||
}
|
||||
|
||||
func (e configError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
68
transport/hysteria2/app/utils/bpsconv.go
Normal file
68
transport/hysteria2/app/utils/bpsconv.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
Byte = 1
|
||||
Kilobyte = Byte * 1000
|
||||
Megabyte = Kilobyte * 1000
|
||||
Gigabyte = Megabyte * 1000
|
||||
Terabyte = Gigabyte * 1000
|
||||
)
|
||||
|
||||
// StringToBps converts a string to a bandwidth value in bytes per second.
|
||||
// E.g. "100 Mbps", "512 kbps", "1g" are all valid.
|
||||
func StringToBps(s string) (uint64, error) {
|
||||
s = strings.ToLower(strings.TrimSpace(s))
|
||||
spl := 0
|
||||
for i, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
spl = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if spl == 0 {
|
||||
// No unit or no value
|
||||
return 0, errors.New("invalid format")
|
||||
}
|
||||
v, err := strconv.ParseUint(s[:spl], 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
unit := strings.TrimSpace(s[spl:])
|
||||
|
||||
switch strings.ToLower(unit) {
|
||||
case "b", "bps":
|
||||
return v * Byte / 8, nil
|
||||
case "k", "kb", "kbps":
|
||||
return v * Kilobyte / 8, nil
|
||||
case "m", "mb", "mbps":
|
||||
return v * Megabyte / 8, nil
|
||||
case "g", "gb", "gbps":
|
||||
return v * Gigabyte / 8, nil
|
||||
case "t", "tb", "tbps":
|
||||
return v * Terabyte / 8, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported unit")
|
||||
}
|
||||
}
|
||||
|
||||
// ConvBandwidth handles both string and int types for bandwidth.
|
||||
// When using string, it will be parsed as a bandwidth string with units.
|
||||
// When using int, it will be parsed as a raw bandwidth in bytes per second.
|
||||
// It does NOT support float types.
|
||||
func ConvBandwidth(bw interface{}) (uint64, error) {
|
||||
switch bwT := bw.(type) {
|
||||
case string:
|
||||
return StringToBps(bwT)
|
||||
case int:
|
||||
return uint64(bwT), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid type %T for bandwidth", bwT)
|
||||
}
|
||||
}
|
316
transport/hysteria2/core/client/client.go
Normal file
316
transport/hysteria2/core/client/client.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/utils"
|
||||
|
||||
"github.com/metacubex/quic-go"
|
||||
"github.com/metacubex/quic-go/http3"
|
||||
)
|
||||
|
||||
const (
|
||||
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
||||
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
TCP(addr string) (net.Conn, error)
|
||||
UDP() (HyUDPConn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type HyUDPConn interface {
|
||||
Receive() ([]byte, string, error)
|
||||
Send([]byte, string) error
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
type HandshakeInfo struct {
|
||||
UDPEnabled bool
|
||||
Tx uint64 // 0 if using BBR
|
||||
}
|
||||
|
||||
func NewClient(config *Config) (Client, *HandshakeInfo, error) {
|
||||
if err := config.verifyAndFill(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c := &clientImpl{
|
||||
config: config,
|
||||
}
|
||||
info, err := c.connect()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return c, info, nil
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
config *Config
|
||||
|
||||
pktConn net.PacketConn
|
||||
conn quic.Connection
|
||||
|
||||
udpSM *udpSessionManager
|
||||
}
|
||||
|
||||
func (c *clientImpl) connect() (*HandshakeInfo, error) {
|
||||
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Convert config to TLS config & QUIC config
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: c.config.TLSConfig.ServerName,
|
||||
InsecureSkipVerify: c.config.TLSConfig.InsecureSkipVerify,
|
||||
VerifyPeerCertificate: c.config.TLSConfig.VerifyPeerCertificate,
|
||||
RootCAs: c.config.TLSConfig.RootCAs,
|
||||
}
|
||||
quicConfig := &quic.Config{
|
||||
InitialStreamReceiveWindow: c.config.QUICConfig.InitialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: c.config.QUICConfig.MaxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: c.config.QUICConfig.InitialConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: c.config.QUICConfig.MaxConnectionReceiveWindow,
|
||||
MaxIdleTimeout: c.config.QUICConfig.MaxIdleTimeout,
|
||||
KeepAlivePeriod: c.config.QUICConfig.KeepAlivePeriod,
|
||||
DisablePathMTUDiscovery: c.config.QUICConfig.DisablePathMTUDiscovery,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
// Prepare RoundTripper
|
||||
var conn quic.EarlyConnection
|
||||
rt := &http3.RoundTripper{
|
||||
EnableDatagrams: true,
|
||||
TLSClientConfig: tlsConfig,
|
||||
QuicConfig: quicConfig,
|
||||
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
qc, err := quic.DialEarly(ctx, pktConn, c.config.ServerAddr, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = qc
|
||||
return qc, nil
|
||||
},
|
||||
}
|
||||
// Send auth HTTP request
|
||||
req := &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: protocol.URLHost,
|
||||
Path: protocol.URLPath,
|
||||
},
|
||||
Header: make(http.Header),
|
||||
}
|
||||
protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{
|
||||
Auth: c.config.Auth,
|
||||
Rx: c.config.BandwidthConfig.MaxRx,
|
||||
})
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
||||
}
|
||||
_ = pktConn.Close()
|
||||
return nil, coreErrs.ConnectError{Err: err}
|
||||
}
|
||||
if resp.StatusCode != protocol.StatusAuthOK {
|
||||
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
||||
_ = pktConn.Close()
|
||||
return nil, coreErrs.AuthError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
// Auth OK
|
||||
authResp := protocol.AuthResponseFromHeader(resp.Header)
|
||||
var actualTx uint64
|
||||
if authResp.RxAuto {
|
||||
// Server asks client to use bandwidth detection,
|
||||
// ignore local bandwidth config and use BBR
|
||||
congestion.UseBBR(conn)
|
||||
} else {
|
||||
// actualTx = min(serverRx, clientTx)
|
||||
actualTx = authResp.Rx
|
||||
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
|
||||
// Server doesn't have a limit, or our clientTx is smaller than serverRx
|
||||
actualTx = c.config.BandwidthConfig.MaxTx
|
||||
}
|
||||
if actualTx > 0 {
|
||||
congestion.UseBrutal(conn, actualTx)
|
||||
} else {
|
||||
// We don't know our own bandwidth either, use BBR
|
||||
congestion.UseBBR(conn)
|
||||
}
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
c.pktConn = pktConn
|
||||
c.conn = conn
|
||||
if authResp.UDPEnabled {
|
||||
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
|
||||
}
|
||||
return &HandshakeInfo{
|
||||
UDPEnabled: authResp.UDPEnabled,
|
||||
Tx: actualTx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// openStream wraps the stream with QStream, which handles Close() properly
|
||||
func (c *clientImpl) openStream() (quic.Stream, error) {
|
||||
stream, err := c.conn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &utils.QStream{Stream: stream}, nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) TCP(addr string) (net.Conn, error) {
|
||||
stream, err := c.openStream()
|
||||
if err != nil {
|
||||
return nil, wrapIfConnectionClosed(err)
|
||||
}
|
||||
// Send request
|
||||
err = protocol.WriteTCPRequest(stream, addr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, wrapIfConnectionClosed(err)
|
||||
}
|
||||
if c.config.FastOpen {
|
||||
// Don't wait for the response when fast open is enabled.
|
||||
// Return the connection immediately, defer the response handling
|
||||
// to the first Read() call.
|
||||
return &tcpConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: c.conn.LocalAddr(),
|
||||
PseudoRemoteAddr: c.conn.RemoteAddr(),
|
||||
Established: false,
|
||||
}, nil
|
||||
}
|
||||
// Read response
|
||||
ok, msg, err := protocol.ReadTCPResponse(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, wrapIfConnectionClosed(err)
|
||||
}
|
||||
if !ok {
|
||||
_ = stream.Close()
|
||||
return nil, coreErrs.DialError{Message: msg}
|
||||
}
|
||||
return &tcpConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: c.conn.LocalAddr(),
|
||||
PseudoRemoteAddr: c.conn.RemoteAddr(),
|
||||
Established: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) UDP() (HyUDPConn, error) {
|
||||
if c.udpSM == nil {
|
||||
return nil, coreErrs.DialError{Message: "UDP not enabled"}
|
||||
}
|
||||
return c.udpSM.NewUDP()
|
||||
}
|
||||
|
||||
func (c *clientImpl) Close() error {
|
||||
_ = c.conn.CloseWithError(closeErrCodeOK, "")
|
||||
_ = c.pktConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapIfConnectionClosed checks if the error returned by quic-go
|
||||
// indicates that the QUIC connection has been permanently closed,
|
||||
// and if so, wraps the error with coreErrs.ClosedError.
|
||||
// PITFALL: sometimes quic-go has "internal errors" that are not net.Error,
|
||||
// but we still need to treat them as ClosedError.
|
||||
func wrapIfConnectionClosed(err error) error {
|
||||
netErr, ok := err.(net.Error)
|
||||
if !ok || !netErr.Temporary() {
|
||||
return coreErrs.ClosedError{Err: err}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
type tcpConn struct {
|
||||
Orig quic.Stream
|
||||
PseudoLocalAddr net.Addr
|
||||
PseudoRemoteAddr net.Addr
|
||||
Established bool
|
||||
}
|
||||
|
||||
func (c *tcpConn) Read(b []byte) (n int, err error) {
|
||||
if !c.Established {
|
||||
// Read response
|
||||
ok, msg, err := protocol.ReadTCPResponse(c.Orig)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !ok {
|
||||
return 0, coreErrs.DialError{Message: msg}
|
||||
}
|
||||
c.Established = true
|
||||
}
|
||||
return c.Orig.Read(b)
|
||||
}
|
||||
|
||||
func (c *tcpConn) Write(b []byte) (n int, err error) {
|
||||
return c.Orig.Write(b)
|
||||
}
|
||||
|
||||
func (c *tcpConn) Close() error {
|
||||
return c.Orig.Close()
|
||||
}
|
||||
|
||||
func (c *tcpConn) LocalAddr() net.Addr {
|
||||
return c.PseudoLocalAddr
|
||||
}
|
||||
|
||||
func (c *tcpConn) RemoteAddr() net.Addr {
|
||||
return c.PseudoRemoteAddr
|
||||
}
|
||||
|
||||
func (c *tcpConn) SetDeadline(t time.Time) error {
|
||||
return c.Orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *tcpConn) SetReadDeadline(t time.Time) error {
|
||||
return c.Orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *tcpConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.Orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type udpIOImpl struct {
|
||||
Conn quic.Connection
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
for {
|
||||
msg, err := io.Conn.ReceiveDatagram(context.Background())
|
||||
if err != nil {
|
||||
// Connection error, this will stop the session manager
|
||||
return nil, err
|
||||
}
|
||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||
if err != nil {
|
||||
// Invalid message, this is fine - just wait for the next
|
||||
continue
|
||||
}
|
||||
return udpMsg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||
msgN := msg.Serialize(buf)
|
||||
if msgN < 0 {
|
||||
// Message larger than buffer, silent drop
|
||||
return nil
|
||||
}
|
||||
return io.Conn.SendDatagram(buf[:msgN])
|
||||
}
|
112
transport/hysteria2/core/client/config.go
Normal file
112
transport/hysteria2/core/client/config.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/errors"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/pmtud"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultStreamReceiveWindow = 8388608 // 8MB
|
||||
defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB
|
||||
defaultMaxIdleTimeout = 30 * time.Second
|
||||
defaultKeepAlivePeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ConnFactory ConnFactory
|
||||
ServerAddr net.Addr
|
||||
Auth string
|
||||
TLSConfig TLSConfig
|
||||
QUICConfig QUICConfig
|
||||
BandwidthConfig BandwidthConfig
|
||||
FastOpen bool
|
||||
|
||||
filled bool // whether the fields have been verified and filled
|
||||
}
|
||||
|
||||
// verifyAndFill fills the fields that are not set by the user with default values when possible,
|
||||
// and returns an error if the user has not set a required field or has set an invalid value.
|
||||
func (c *Config) verifyAndFill() error {
|
||||
if c.filled {
|
||||
return nil
|
||||
}
|
||||
if c.ConnFactory == nil {
|
||||
c.ConnFactory = &udpConnFactory{}
|
||||
}
|
||||
if c.ServerAddr == nil {
|
||||
return errors.ConfigError{Field: "ServerAddr", Reason: "must be set"}
|
||||
}
|
||||
if c.QUICConfig.InitialStreamReceiveWindow == 0 {
|
||||
c.QUICConfig.InitialStreamReceiveWindow = defaultStreamReceiveWindow
|
||||
} else if c.QUICConfig.InitialStreamReceiveWindow < 16384 {
|
||||
return errors.ConfigError{Field: "QUICConfig.InitialStreamReceiveWindow", Reason: "must be at least 16384"}
|
||||
}
|
||||
if c.QUICConfig.MaxStreamReceiveWindow == 0 {
|
||||
c.QUICConfig.MaxStreamReceiveWindow = defaultStreamReceiveWindow
|
||||
} else if c.QUICConfig.MaxStreamReceiveWindow < 16384 {
|
||||
return errors.ConfigError{Field: "QUICConfig.MaxStreamReceiveWindow", Reason: "must be at least 16384"}
|
||||
}
|
||||
if c.QUICConfig.InitialConnectionReceiveWindow == 0 {
|
||||
c.QUICConfig.InitialConnectionReceiveWindow = defaultConnReceiveWindow
|
||||
} else if c.QUICConfig.InitialConnectionReceiveWindow < 16384 {
|
||||
return errors.ConfigError{Field: "QUICConfig.InitialConnectionReceiveWindow", Reason: "must be at least 16384"}
|
||||
}
|
||||
if c.QUICConfig.MaxConnectionReceiveWindow == 0 {
|
||||
c.QUICConfig.MaxConnectionReceiveWindow = defaultConnReceiveWindow
|
||||
} else if c.QUICConfig.MaxConnectionReceiveWindow < 16384 {
|
||||
return errors.ConfigError{Field: "QUICConfig.MaxConnectionReceiveWindow", Reason: "must be at least 16384"}
|
||||
}
|
||||
if c.QUICConfig.MaxIdleTimeout == 0 {
|
||||
c.QUICConfig.MaxIdleTimeout = defaultMaxIdleTimeout
|
||||
} else if c.QUICConfig.MaxIdleTimeout < 4*time.Second || c.QUICConfig.MaxIdleTimeout > 120*time.Second {
|
||||
return errors.ConfigError{Field: "QUICConfig.MaxIdleTimeout", Reason: "must be between 4s and 120s"}
|
||||
}
|
||||
if c.QUICConfig.KeepAlivePeriod == 0 {
|
||||
c.QUICConfig.KeepAlivePeriod = defaultKeepAlivePeriod
|
||||
} else if c.QUICConfig.KeepAlivePeriod < 2*time.Second || c.QUICConfig.KeepAlivePeriod > 60*time.Second {
|
||||
return errors.ConfigError{Field: "QUICConfig.KeepAlivePeriod", Reason: "must be between 2s and 60s"}
|
||||
}
|
||||
c.QUICConfig.DisablePathMTUDiscovery = c.QUICConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery
|
||||
|
||||
c.filled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
type ConnFactory interface {
|
||||
New(net.Addr) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
type udpConnFactory struct{}
|
||||
|
||||
func (f *udpConnFactory) New(addr net.Addr) (net.PacketConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
}
|
||||
|
||||
// TLSConfig contains the TLS configuration fields that we want to expose to the user.
|
||||
type TLSConfig struct {
|
||||
ServerName string
|
||||
InsecureSkipVerify bool
|
||||
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||
RootCAs *x509.CertPool
|
||||
}
|
||||
|
||||
// QUICConfig contains the QUIC configuration fields that we want to expose to the user.
|
||||
type QUICConfig struct {
|
||||
InitialStreamReceiveWindow uint64
|
||||
MaxStreamReceiveWindow uint64
|
||||
InitialConnectionReceiveWindow uint64
|
||||
MaxConnectionReceiveWindow uint64
|
||||
MaxIdleTimeout time.Duration
|
||||
KeepAlivePeriod time.Duration
|
||||
DisablePathMTUDiscovery bool // The server may still override this to true on unsupported platforms.
|
||||
}
|
||||
|
||||
// BandwidthConfig describes the maximum bandwidth that the server can use, in bytes per second.
|
||||
type BandwidthConfig struct {
|
||||
MaxTx uint64
|
||||
MaxRx uint64
|
||||
}
|
117
transport/hysteria2/core/client/reconnect.go
Normal file
117
transport/hysteria2/core/client/reconnect.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors"
|
||||
)
|
||||
|
||||
// reconnectableClientImpl is a wrapper of Client, which can reconnect when the connection is closed,
|
||||
// except when the caller explicitly calls Close() to permanently close this client.
|
||||
type reconnectableClientImpl struct {
|
||||
configFunc func() (*Config, error) // called before connecting
|
||||
connectedFunc func(Client, *HandshakeInfo, int) // called when successfully connected
|
||||
client Client
|
||||
count int
|
||||
m sync.Mutex
|
||||
closed bool // permanent close
|
||||
}
|
||||
|
||||
// NewReconnectableClient creates a reconnectable client.
|
||||
// If lazy is true, the client will not connect until the first call to TCP() or UDP().
|
||||
// We use a function for config mainly to delay config evaluation
|
||||
// (which involves DNS resolution) until the actual connection attempt.
|
||||
func NewReconnectableClient(configFunc func() (*Config, error), connectedFunc func(Client, *HandshakeInfo, int), lazy bool) (Client, error) {
|
||||
rc := &reconnectableClientImpl{
|
||||
configFunc: configFunc,
|
||||
connectedFunc: connectedFunc,
|
||||
}
|
||||
if !lazy {
|
||||
if err := rc.reconnect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func (rc *reconnectableClientImpl) reconnect() error {
|
||||
if rc.client != nil {
|
||||
_ = rc.client.Close()
|
||||
}
|
||||
var info *HandshakeInfo
|
||||
config, err := rc.configFunc()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc.client, info, err = NewClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
rc.count++
|
||||
if rc.connectedFunc != nil {
|
||||
rc.connectedFunc(rc, info, rc.count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) {
|
||||
rc.m.Lock()
|
||||
defer rc.m.Unlock()
|
||||
if rc.closed {
|
||||
return nil, coreErrs.ClosedError{}
|
||||
}
|
||||
if rc.client == nil {
|
||||
// No active connection, connect first
|
||||
if err := rc.reconnect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
conn, err := rc.client.TCP(addr)
|
||||
if _, ok := err.(coreErrs.ClosedError); ok {
|
||||
// Connection closed, reconnect
|
||||
if err := rc.reconnect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rc.client.TCP(addr)
|
||||
} else {
|
||||
// OK or some other temporary error
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) {
|
||||
rc.m.Lock()
|
||||
defer rc.m.Unlock()
|
||||
if rc.closed {
|
||||
return nil, coreErrs.ClosedError{}
|
||||
}
|
||||
if rc.client == nil {
|
||||
// No active connection, connect first
|
||||
if err := rc.reconnect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
conn, err := rc.client.UDP()
|
||||
if _, ok := err.(coreErrs.ClosedError); ok {
|
||||
// Connection closed, reconnect
|
||||
if err := rc.reconnect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rc.client.UDP()
|
||||
} else {
|
||||
// OK or some other temporary error
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *reconnectableClientImpl) Close() error {
|
||||
rc.m.Lock()
|
||||
defer rc.m.Unlock()
|
||||
rc.closed = true
|
||||
if rc.client != nil {
|
||||
return rc.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
223
transport/hysteria2/core/client/udp.go
Normal file
223
transport/hysteria2/core/client/udp.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
coreErrs "github.com/metacubex/mihomo/transport/hysteria2/core/errors"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/frag"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol"
|
||||
"github.com/metacubex/quic-go"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
udpMessageChanSize = 1024
|
||||
)
|
||||
|
||||
type udpIO interface {
|
||||
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||
SendMessage([]byte, *protocol.UDPMessage) error
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
ID uint32
|
||||
D *frag.Defragger
|
||||
ReceiveCh chan *protocol.UDPMessage
|
||||
SendBuf []byte
|
||||
SendFunc func([]byte, *protocol.UDPMessage) error
|
||||
CloseFunc func()
|
||||
Closed bool
|
||||
}
|
||||
|
||||
func (u *udpConn) Receive() ([]byte, string, error) {
|
||||
for {
|
||||
msg := <-u.ReceiveCh
|
||||
if msg == nil {
|
||||
// Closed
|
||||
return nil, "", io.EOF
|
||||
}
|
||||
dfMsg := u.D.Feed(msg)
|
||||
if dfMsg == nil {
|
||||
// Incomplete message, wait for more
|
||||
continue
|
||||
}
|
||||
return dfMsg.Data, dfMsg.Addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
bs, addrStr, err := u.Receive()
|
||||
n = copy(p, bs)
|
||||
addr = M.ParseSocksaddr(addrStr).UDPAddr()
|
||||
return
|
||||
}
|
||||
|
||||
// Send is not thread-safe, as it uses a shared SendBuf.
|
||||
func (u *udpConn) Send(data []byte, addr string) error {
|
||||
// Try no frag first
|
||||
msg := &protocol.UDPMessage{
|
||||
SessionID: u.ID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: addr,
|
||||
Data: data,
|
||||
}
|
||||
err := u.SendFunc(u.SendBuf, msg)
|
||||
var errTooLarge quic.ErrMessageTooLarge
|
||||
if errors.As(err, &errTooLarge) {
|
||||
// Message too large, try fragmentation
|
||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||
for _, fMsg := range fMsgs {
|
||||
err := u.SendFunc(u.SendBuf, &fMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
err = u.Send(p, M.SocksaddrFromNet(addr).String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = len(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (u *udpConn) Close() error {
|
||||
u.CloseFunc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *udpConn) LocalAddr() net.Addr {
|
||||
// a fake implementation to satisfy net.PacketConn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *udpConn) SetDeadline(t time.Time) error {
|
||||
// a fake implementation to satisfy net.PacketConn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *udpConn) SetReadDeadline(t time.Time) error {
|
||||
// a fake implementation to satisfy net.PacketConn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *udpConn) SetWriteDeadline(t time.Time) error {
|
||||
// a fake implementation to satisfy net.PacketConn
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpSessionManager struct {
|
||||
io udpIO
|
||||
|
||||
mutex sync.RWMutex
|
||||
m map[uint32]*udpConn
|
||||
nextID uint32
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newUDPSessionManager(io udpIO) *udpSessionManager {
|
||||
m := &udpSessionManager{
|
||||
io: io,
|
||||
m: make(map[uint32]*udpConn),
|
||||
nextID: 1,
|
||||
}
|
||||
go m.run()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) run() error {
|
||||
defer m.closeCleanup()
|
||||
for {
|
||||
msg, err := m.io.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.feed(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) closeCleanup() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, conn := range m.m {
|
||||
m.close(conn)
|
||||
}
|
||||
m.closed = true
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
conn, ok := m.m[msg.SessionID]
|
||||
if !ok {
|
||||
// Ignore message from unknown session
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case conn.ReceiveCh <- msg:
|
||||
// OK
|
||||
default:
|
||||
// Channel full, drop the message
|
||||
}
|
||||
}
|
||||
|
||||
// NewUDP creates a new UDP session.
|
||||
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, coreErrs.ClosedError{}
|
||||
}
|
||||
|
||||
id := m.nextID
|
||||
m.nextID++
|
||||
|
||||
conn := &udpConn{
|
||||
ID: id,
|
||||
D: &frag.Defragger{},
|
||||
ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize),
|
||||
SendBuf: make([]byte, protocol.MaxUDPSize),
|
||||
SendFunc: m.io.SendMessage,
|
||||
}
|
||||
conn.CloseFunc = func() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.close(conn)
|
||||
}
|
||||
m.m[id] = conn
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) close(conn *udpConn) {
|
||||
if !conn.Closed {
|
||||
conn.Closed = true
|
||||
close(conn.ReceiveCh)
|
||||
delete(m.m, conn.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) Count() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return len(m.m)
|
||||
}
|
75
transport/hysteria2/core/errors/errors.go
Normal file
75
transport/hysteria2/core/errors/errors.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ConfigError is returned when a configuration field is invalid.
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (c ConfigError) Error() string {
|
||||
return fmt.Sprintf("invalid config: %s: %s", c.Field, c.Reason)
|
||||
}
|
||||
|
||||
// ConnectError is returned when the client fails to connect to the server.
|
||||
type ConnectError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (c ConnectError) Error() string {
|
||||
return "connect error: " + c.Err.Error()
|
||||
}
|
||||
|
||||
func (c ConnectError) Unwrap() error {
|
||||
return c.Err
|
||||
}
|
||||
|
||||
// AuthError is returned when the client fails to authenticate with the server.
|
||||
type AuthError struct {
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (a AuthError) Error() string {
|
||||
return "authentication error, HTTP status code: " + strconv.Itoa(a.StatusCode)
|
||||
}
|
||||
|
||||
// DialError is returned when the server rejects the client's dial request.
|
||||
// This applies to both TCP and UDP.
|
||||
type DialError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (c DialError) Error() string {
|
||||
return "dial error: " + c.Message
|
||||
}
|
||||
|
||||
// ClosedError is returned when the client attempts to use a closed connection.
|
||||
type ClosedError struct {
|
||||
Err error // Can be nil
|
||||
}
|
||||
|
||||
func (c ClosedError) Error() string {
|
||||
if c.Err == nil {
|
||||
return "connection closed"
|
||||
} else {
|
||||
return "connection closed: " + c.Err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func (c ClosedError) Unwrap() error {
|
||||
return c.Err
|
||||
}
|
||||
|
||||
// ProtocolError is returned when the server/client runs into an unexpected
|
||||
// or malformed request/response/message.
|
||||
type ProtocolError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (p ProtocolError) Error() string {
|
||||
return "protocol error: " + p.Message
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package bbr
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
infBandwidth = Bandwidth(math.MaxUint64)
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
||||
type Bandwidth uint64
|
||||
|
||||
const (
|
||||
// BitsPerSecond is 1 bit per second
|
||||
BitsPerSecond Bandwidth = 1
|
||||
// BytesPerSecond is 1 byte per second
|
||||
BytesPerSecond = 8 * BitsPerSecond
|
||||
)
|
||||
|
||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||
func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth {
|
||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||
}
|
|
@ -0,0 +1,874 @@
|
|||
package bbr
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
infRTT = time.Duration(math.MaxInt64)
|
||||
defaultConnectionStateMapQueueSize = 256
|
||||
defaultCandidatesBufferSize = 256
|
||||
)
|
||||
|
||||
type roundTripCount uint64
|
||||
|
||||
// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned
|
||||
// to the caller when the packet is acked or lost.
|
||||
type sendTimeState struct {
|
||||
// Whether other states in this object is valid.
|
||||
isValid bool
|
||||
// Whether the sender is app limited at the time the packet was sent.
|
||||
// App limited bandwidth sample might be artificially low because the sender
|
||||
// did not have enough data to send in order to saturate the link.
|
||||
isAppLimited bool
|
||||
// Total number of sent bytes at the time the packet was sent.
|
||||
// Includes the packet itself.
|
||||
totalBytesSent congestion.ByteCount
|
||||
// Total number of acked bytes at the time the packet was sent.
|
||||
totalBytesAcked congestion.ByteCount
|
||||
// Total number of lost bytes at the time the packet was sent.
|
||||
totalBytesLost congestion.ByteCount
|
||||
// Total number of inflight bytes at the time the packet was sent.
|
||||
// Includes the packet itself.
|
||||
// It should be equal to |total_bytes_sent| minus the sum of
|
||||
// |total_bytes_acked|, |total_bytes_lost| and total neutered bytes.
|
||||
bytesInFlight congestion.ByteCount
|
||||
}
|
||||
|
||||
func newSendTimeState(
|
||||
isAppLimited bool,
|
||||
totalBytesSent congestion.ByteCount,
|
||||
totalBytesAcked congestion.ByteCount,
|
||||
totalBytesLost congestion.ByteCount,
|
||||
bytesInFlight congestion.ByteCount,
|
||||
) *sendTimeState {
|
||||
return &sendTimeState{
|
||||
isValid: true,
|
||||
isAppLimited: isAppLimited,
|
||||
totalBytesSent: totalBytesSent,
|
||||
totalBytesAcked: totalBytesAcked,
|
||||
totalBytesLost: totalBytesLost,
|
||||
bytesInFlight: bytesInFlight,
|
||||
}
|
||||
}
|
||||
|
||||
type extraAckedEvent struct {
|
||||
// The excess bytes acknowlwedged in the time delta for this event.
|
||||
extraAcked congestion.ByteCount
|
||||
|
||||
// The bytes acknowledged and time delta from the event.
|
||||
bytesAcked congestion.ByteCount
|
||||
timeDelta time.Duration
|
||||
// The round trip of the event.
|
||||
round roundTripCount
|
||||
}
|
||||
|
||||
func maxExtraAckedEventFunc(a, b extraAckedEvent) int {
|
||||
if a.extraAcked > b.extraAcked {
|
||||
return 1
|
||||
} else if a.extraAcked < b.extraAcked {
|
||||
return -1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// BandwidthSample
|
||||
type bandwidthSample struct {
|
||||
// The bandwidth at that particular sample. Zero if no valid bandwidth sample
|
||||
// is available.
|
||||
bandwidth Bandwidth
|
||||
// The RTT measurement at this particular sample. Zero if no RTT sample is
|
||||
// available. Does not correct for delayed ack time.
|
||||
rtt time.Duration
|
||||
// |send_rate| is computed from the current packet being acked('P') and an
|
||||
// earlier packet that is acked before P was sent.
|
||||
sendRate Bandwidth
|
||||
// States captured when the packet was sent.
|
||||
stateAtSend sendTimeState
|
||||
}
|
||||
|
||||
func newBandwidthSample() *bandwidthSample {
|
||||
return &bandwidthSample{
|
||||
sendRate: infBandwidth,
|
||||
}
|
||||
}
|
||||
|
||||
// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every
|
||||
// ack event to keep track the degree of ack aggregation(a.k.a "ack height").
|
||||
type maxAckHeightTracker struct {
|
||||
// Tracks the maximum number of bytes acked faster than the estimated
|
||||
// bandwidth.
|
||||
maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount]
|
||||
// The time this aggregation started and the number of bytes acked during it.
|
||||
aggregationEpochStartTime time.Time
|
||||
aggregationEpochBytes congestion.ByteCount
|
||||
// The last sent packet number before the current aggregation epoch started.
|
||||
lastSentPacketNumberBeforeEpoch congestion.PacketNumber
|
||||
// The number of ack aggregation epochs ever started, including the ongoing
|
||||
// one. Stats only.
|
||||
numAckAggregationEpochs uint64
|
||||
ackAggregationBandwidthThreshold float64
|
||||
startNewAggregationEpochAfterFullRound bool
|
||||
reduceExtraAckedOnBandwidthIncrease bool
|
||||
}
|
||||
|
||||
func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker {
|
||||
return &maxAckHeightTracker{
|
||||
maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc),
|
||||
lastSentPacketNumberBeforeEpoch: invalidPacketNumber,
|
||||
ackAggregationBandwidthThreshold: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) Get() congestion.ByteCount {
|
||||
return m.maxAckHeightFilter.GetBest().extraAcked
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) Update(
|
||||
bandwidthEstimate Bandwidth,
|
||||
isNewMaxBandwidth bool,
|
||||
roundTripCount roundTripCount,
|
||||
lastSentPacketNumber congestion.PacketNumber,
|
||||
lastAckedPacketNumber congestion.PacketNumber,
|
||||
ackTime time.Time,
|
||||
bytesAcked congestion.ByteCount,
|
||||
) congestion.ByteCount {
|
||||
forceNewEpoch := false
|
||||
|
||||
if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth {
|
||||
// Save and clear existing entries.
|
||||
best := m.maxAckHeightFilter.GetBest()
|
||||
secondBest := m.maxAckHeightFilter.GetSecondBest()
|
||||
thirdBest := m.maxAckHeightFilter.GetThirdBest()
|
||||
m.maxAckHeightFilter.Clear()
|
||||
|
||||
// Reinsert the heights into the filter after recalculating.
|
||||
expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta)
|
||||
if expectedBytesAcked < best.bytesAcked {
|
||||
best.extraAcked = best.bytesAcked - expectedBytesAcked
|
||||
m.maxAckHeightFilter.Update(best, best.round)
|
||||
}
|
||||
expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta)
|
||||
if expectedBytesAcked < secondBest.bytesAcked {
|
||||
secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked
|
||||
m.maxAckHeightFilter.Update(secondBest, secondBest.round)
|
||||
}
|
||||
expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta)
|
||||
if expectedBytesAcked < thirdBest.bytesAcked {
|
||||
thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked
|
||||
m.maxAckHeightFilter.Update(thirdBest, thirdBest.round)
|
||||
}
|
||||
}
|
||||
|
||||
// If any packet sent after the start of the epoch has been acked, start a new
|
||||
// epoch.
|
||||
if m.startNewAggregationEpochAfterFullRound &&
|
||||
m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber &&
|
||||
lastAckedPacketNumber != invalidPacketNumber &&
|
||||
lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch {
|
||||
forceNewEpoch = true
|
||||
}
|
||||
if m.aggregationEpochStartTime.IsZero() || forceNewEpoch {
|
||||
m.aggregationEpochBytes = bytesAcked
|
||||
m.aggregationEpochStartTime = ackTime
|
||||
m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber
|
||||
m.numAckAggregationEpochs++
|
||||
return 0
|
||||
}
|
||||
|
||||
// Compute how many bytes are expected to be delivered, assuming max bandwidth
|
||||
// is correct.
|
||||
aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime)
|
||||
expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta)
|
||||
// Reset the current aggregation epoch as soon as the ack arrival rate is less
|
||||
// than or equal to the max bandwidth.
|
||||
if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) {
|
||||
// Reset to start measuring a new aggregation epoch.
|
||||
m.aggregationEpochBytes = bytesAcked
|
||||
m.aggregationEpochStartTime = ackTime
|
||||
m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber
|
||||
m.numAckAggregationEpochs++
|
||||
return 0
|
||||
}
|
||||
|
||||
m.aggregationEpochBytes += bytesAcked
|
||||
|
||||
// Compute how many extra bytes were delivered vs max bandwidth.
|
||||
extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked
|
||||
newEvent := extraAckedEvent{
|
||||
extraAcked: expectedBytesAcked,
|
||||
bytesAcked: m.aggregationEpochBytes,
|
||||
timeDelta: aggregationDelta,
|
||||
}
|
||||
m.maxAckHeightFilter.Update(newEvent, roundTripCount)
|
||||
return extraBytesAcked
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) {
|
||||
m.maxAckHeightFilter.SetWindowLength(length)
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) {
|
||||
newEvent := extraAckedEvent{
|
||||
extraAcked: newHeight,
|
||||
round: newTime,
|
||||
}
|
||||
m.maxAckHeightFilter.Reset(newEvent, newTime)
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) {
|
||||
m.ackAggregationBandwidthThreshold = threshold
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) {
|
||||
m.startNewAggregationEpochAfterFullRound = value
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) {
|
||||
m.reduceExtraAckedOnBandwidthIncrease = value
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 {
|
||||
return m.ackAggregationBandwidthThreshold
|
||||
}
|
||||
|
||||
func (m *maxAckHeightTracker) NumAckAggregationEpochs() uint64 {
|
||||
return m.numAckAggregationEpochs
|
||||
}
|
||||
|
||||
// AckPoint represents a point on the ack line.
|
||||
type ackPoint struct {
|
||||
ackTime time.Time
|
||||
totalBytesAcked congestion.ByteCount
|
||||
}
|
||||
|
||||
// RecentAckPoints maintains the most recent 2 ack points at distinct times.
|
||||
type recentAckPoints struct {
|
||||
ackPoints [2]ackPoint
|
||||
}
|
||||
|
||||
func (r *recentAckPoints) Update(ackTime time.Time, totalBytesAcked congestion.ByteCount) {
|
||||
if ackTime.Before(r.ackPoints[1].ackTime) {
|
||||
r.ackPoints[1].ackTime = ackTime
|
||||
} else if ackTime.After(r.ackPoints[1].ackTime) {
|
||||
r.ackPoints[0] = r.ackPoints[1]
|
||||
r.ackPoints[1].ackTime = ackTime
|
||||
}
|
||||
|
||||
r.ackPoints[1].totalBytesAcked = totalBytesAcked
|
||||
}
|
||||
|
||||
func (r *recentAckPoints) Clear() {
|
||||
r.ackPoints[0] = ackPoint{}
|
||||
r.ackPoints[1] = ackPoint{}
|
||||
}
|
||||
|
||||
func (r *recentAckPoints) MostRecentPoint() *ackPoint {
|
||||
return &r.ackPoints[1]
|
||||
}
|
||||
|
||||
func (r *recentAckPoints) LessRecentPoint() *ackPoint {
|
||||
if r.ackPoints[0].totalBytesAcked != 0 {
|
||||
return &r.ackPoints[0]
|
||||
}
|
||||
|
||||
return &r.ackPoints[1]
|
||||
}
|
||||
|
||||
// ConnectionStateOnSentPacket represents the information about a sent packet
|
||||
// and the state of the connection at the moment the packet was sent,
|
||||
// specifically the information about the most recently acknowledged packet at
|
||||
// that moment.
|
||||
type connectionStateOnSentPacket struct {
|
||||
// Time at which the packet is sent.
|
||||
sentTime time.Time
|
||||
// Size of the packet.
|
||||
size congestion.ByteCount
|
||||
// The value of |totalBytesSentAtLastAckedPacket| at the time the
|
||||
// packet was sent.
|
||||
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
||||
// The value of |lastAckedPacketSentTime| at the time the packet was
|
||||
// sent.
|
||||
lastAckedPacketSentTime time.Time
|
||||
// The value of |lastAckedPacketAckTime| at the time the packet was
|
||||
// sent.
|
||||
lastAckedPacketAckTime time.Time
|
||||
// Send time states that are returned to the congestion controller when the
|
||||
// packet is acked or lost.
|
||||
sendTimeState sendTimeState
|
||||
}
|
||||
|
||||
// Snapshot constructor. Records the current state of the bandwidth
|
||||
// sampler.
|
||||
// |bytes_in_flight| is the bytes in flight right after the packet is sent.
|
||||
func newConnectionStateOnSentPacket(
|
||||
sentTime time.Time,
|
||||
size congestion.ByteCount,
|
||||
bytesInFlight congestion.ByteCount,
|
||||
sampler *bandwidthSampler,
|
||||
) *connectionStateOnSentPacket {
|
||||
return &connectionStateOnSentPacket{
|
||||
sentTime: sentTime,
|
||||
size: size,
|
||||
totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
|
||||
lastAckedPacketSentTime: sampler.lastAckedPacketSentTime,
|
||||
lastAckedPacketAckTime: sampler.lastAckedPacketAckTime,
|
||||
sendTimeState: *newSendTimeState(
|
||||
sampler.isAppLimited,
|
||||
sampler.totalBytesSent,
|
||||
sampler.totalBytesAcked,
|
||||
sampler.totalBytesLost,
|
||||
bytesInFlight,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
|
||||
// bandwidth sample for every packet acknowledged. The samples are taken for
|
||||
// individual packets, and are not filtered; the consumer has to filter the
|
||||
// bandwidth samples itself. In certain cases, the sampler will locally severely
|
||||
// underestimate the bandwidth, hence a maximum filter with a size of at least
|
||||
// one RTT is recommended.
|
||||
//
|
||||
// This class bases its samples on the slope of two curves: the number of bytes
|
||||
// sent over time, and the number of bytes acknowledged as received over time.
|
||||
// It produces a sample of both slopes for every packet that gets acknowledged,
|
||||
// based on a slope between two points on each of the corresponding curves. Note
|
||||
// that due to the packet loss, the number of bytes on each curve might get
|
||||
// further and further away from each other, meaning that it is not feasible to
|
||||
// compare byte values coming from different curves with each other.
|
||||
//
|
||||
// The obvious points for measuring slope sample are the ones corresponding to
|
||||
// the packet that was just acknowledged. Let us denote them as S_1 (point at
|
||||
// which the current packet was sent) and A_1 (point at which the current packet
|
||||
// was acknowledged). However, taking a slope requires two points on each line,
|
||||
// so estimating bandwidth requires picking a packet in the past with respect to
|
||||
// which the slope is measured.
|
||||
//
|
||||
// For that purpose, BandwidthSampler always keeps track of the most recently
|
||||
// acknowledged packet, and records it together with every outgoing packet.
|
||||
// When a packet gets acknowledged (A_1), it has not only information about when
|
||||
// it itself was sent (S_1), but also the information about the latest
|
||||
// acknowledged packet right before it was sent (S_0 and A_0).
|
||||
//
|
||||
// Based on that data, send and ack rate are estimated as:
|
||||
//
|
||||
// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
|
||||
// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
|
||||
//
|
||||
// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
|
||||
// However, in certain cases (e.g. ack compression) the ack rate at a point may
|
||||
// end up higher than the rate at which the data was originally sent, which is
|
||||
// not indicative of the real bandwidth. Hence, we use the send rate as an upper
|
||||
// bound, and the sample value is
|
||||
//
|
||||
// rate_sample = min(send_rate, ack_rate)
|
||||
//
|
||||
// An important edge case handled by the sampler is tracking the app-limited
|
||||
// samples. There are multiple meaning of "app-limited" used interchangeably,
|
||||
// hence it is important to understand and to be able to distinguish between
|
||||
// them.
|
||||
//
|
||||
// Meaning 1: connection state. The connection is said to be app-limited when
|
||||
// there is no outstanding data to send. This means that certain bandwidth
|
||||
// samples in the future would not be an accurate indication of the link
|
||||
// capacity, and it is important to inform consumer about that. Whenever
|
||||
// connection becomes app-limited, the sampler is notified via OnAppLimited()
|
||||
// method.
|
||||
//
|
||||
// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
|
||||
// sampler becomes notified about the connection being app-limited, it enters
|
||||
// app-limited phase. In that phase, all *sent* packets are marked as
|
||||
// app-limited. Note that the connection itself does not have to be
|
||||
// app-limited during the app-limited phase, and in fact it will not be
|
||||
// (otherwise how would it send packets?). The boolean flag below indicates
|
||||
// whether the sampler is in that phase.
|
||||
//
|
||||
// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
|
||||
// sent during the app-limited phase, the resulting sample related to the
|
||||
// packet will be marked as app-limited.
|
||||
//
|
||||
// With the terminology issue out of the way, let us consider the question of
|
||||
// what kind of situation it addresses.
|
||||
//
|
||||
// Consider a scenario where we first send packets 1 to 20 at a regular
|
||||
// bandwidth, and then immediately run out of data. After a few seconds, we send
|
||||
// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
|
||||
// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
|
||||
// we use to compute the slope is going to be packet 20, a few seconds apart
|
||||
// from the current packet, hence the resulting estimate would be extremely low
|
||||
// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
|
||||
// meaning that the bandwidth sample would exclude the quiescence.
|
||||
//
|
||||
// Based on the analysis of that scenario, we implement the following rule: once
|
||||
// OnAppLimited() is called, all sent packets will produce app-limited samples
|
||||
// up until an ack for a packet that was sent after OnAppLimited() was called.
|
||||
// Note that while the scenario above is not the only scenario when the
|
||||
// connection is app-limited, the approach works in other cases too.
|
||||
|
||||
type congestionEventSample struct {
|
||||
// The maximum bandwidth sample from all acked packets.
|
||||
// QuicBandwidth::Zero() if no samples are available.
|
||||
sampleMaxBandwidth Bandwidth
|
||||
// Whether |sample_max_bandwidth| is from a app-limited sample.
|
||||
sampleIsAppLimited bool
|
||||
// The minimum rtt sample from all acked packets.
|
||||
// QuicTime::Delta::Infinite() if no samples are available.
|
||||
sampleRtt time.Duration
|
||||
// For each packet p in acked packets, this is the max value of INFLIGHT(p),
|
||||
// where INFLIGHT(p) is the number of bytes acked while p is inflight.
|
||||
sampleMaxInflight congestion.ByteCount
|
||||
// The send state of the largest packet in acked_packets, unless it is
|
||||
// empty. If acked_packets is empty, it's the send state of the largest
|
||||
// packet in lost_packets.
|
||||
lastPacketSendState sendTimeState
|
||||
// The number of extra bytes acked from this ack event, compared to what is
|
||||
// expected from the flow's bandwidth. Larger value means more ack
|
||||
// aggregation.
|
||||
extraAcked congestion.ByteCount
|
||||
}
|
||||
|
||||
func newCongestionEventSample() *congestionEventSample {
|
||||
return &congestionEventSample{
|
||||
sampleRtt: infRTT,
|
||||
}
|
||||
}
|
||||
|
||||
type bandwidthSampler struct {
|
||||
// The total number of congestion controlled bytes sent during the connection.
|
||||
totalBytesSent congestion.ByteCount
|
||||
|
||||
// The total number of congestion controlled bytes which were acknowledged.
|
||||
totalBytesAcked congestion.ByteCount
|
||||
|
||||
// The total number of congestion controlled bytes which were lost.
|
||||
totalBytesLost congestion.ByteCount
|
||||
|
||||
// The total number of congestion controlled bytes which have been neutered.
|
||||
totalBytesNeutered congestion.ByteCount
|
||||
|
||||
// The value of |total_bytes_sent_| at the time the last acknowledged packet
|
||||
// was sent. Valid only when |last_acked_packet_sent_time_| is valid.
|
||||
totalBytesSentAtLastAckedPacket congestion.ByteCount
|
||||
|
||||
// The time at which the last acknowledged packet was sent. Set to
|
||||
// QuicTime::Zero() if no valid timestamp is available.
|
||||
lastAckedPacketSentTime time.Time
|
||||
|
||||
// The time at which the most recent packet was acknowledged.
|
||||
lastAckedPacketAckTime time.Time
|
||||
|
||||
// The most recently sent packet.
|
||||
lastSentPacket congestion.PacketNumber
|
||||
|
||||
// The most recently acked packet.
|
||||
lastAckedPacket congestion.PacketNumber
|
||||
|
||||
// Indicates whether the bandwidth sampler is currently in an app-limited
|
||||
// phase.
|
||||
isAppLimited bool
|
||||
|
||||
// The packet that will be acknowledged after this one will cause the sampler
|
||||
// to exit the app-limited phase.
|
||||
endOfAppLimitedPhase congestion.PacketNumber
|
||||
|
||||
// Record of the connection state at the point where each packet in flight was
|
||||
// sent, indexed by the packet number.
|
||||
connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket]
|
||||
|
||||
recentAckPoints recentAckPoints
|
||||
a0Candidates RingBuffer[ackPoint]
|
||||
|
||||
// Maximum number of tracked packets.
|
||||
maxTrackedPackets congestion.ByteCount
|
||||
|
||||
maxAckHeightTracker *maxAckHeightTracker
|
||||
totalBytesAckedAfterLastAckEvent congestion.ByteCount
|
||||
|
||||
// True if connection option 'BSAO' is set.
|
||||
overestimateAvoidance bool
|
||||
|
||||
// True if connection option 'BBRB' is set.
|
||||
limitMaxAckHeightTrackerBySendRate bool
|
||||
}
|
||||
|
||||
func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler {
|
||||
b := &bandwidthSampler{
|
||||
maxAckHeightTracker: newMaxAckHeightTracker(maxAckHeightTrackerWindowLength),
|
||||
connectionStateMap: newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize),
|
||||
lastSentPacket: invalidPacketNumber,
|
||||
lastAckedPacket: invalidPacketNumber,
|
||||
endOfAppLimitedPhase: invalidPacketNumber,
|
||||
}
|
||||
|
||||
b.a0Candidates.Init(defaultCandidatesBufferSize)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount {
|
||||
return b.maxAckHeightTracker.Get()
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 {
|
||||
return b.maxAckHeightTracker.NumAckAggregationEpochs()
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) {
|
||||
b.maxAckHeightTracker.SetFilterWindowLength(length)
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) {
|
||||
b.maxAckHeightTracker.Reset(newHeight, newTime)
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) {
|
||||
b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value)
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) {
|
||||
b.limitMaxAckHeightTrackerBySendRate = value
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) {
|
||||
b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value)
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) EnableOverestimateAvoidance() {
|
||||
if b.overestimateAvoidance {
|
||||
return
|
||||
}
|
||||
|
||||
b.overestimateAvoidance = true
|
||||
b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0)
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) IsOverestimateAvoidanceEnabled() bool {
|
||||
return b.overestimateAvoidance
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
packetNumber congestion.PacketNumber,
|
||||
bytes congestion.ByteCount,
|
||||
bytesInFlight congestion.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
b.lastSentPacket = packetNumber
|
||||
|
||||
if !isRetransmittable {
|
||||
return
|
||||
}
|
||||
|
||||
b.totalBytesSent += bytes
|
||||
|
||||
// If there are no packets in flight, the time at which the new transmission
|
||||
// opens can be treated as the A_0 point for the purpose of bandwidth
|
||||
// sampling. This underestimates bandwidth to some extent, and produces some
|
||||
// artificially low samples for most packets in flight, but it provides with
|
||||
// samples at important points where we would not have them otherwise, most
|
||||
// importantly at the beginning of the connection.
|
||||
if bytesInFlight == 0 {
|
||||
b.lastAckedPacketAckTime = sentTime
|
||||
if b.overestimateAvoidance {
|
||||
b.recentAckPoints.Clear()
|
||||
b.recentAckPoints.Update(sentTime, b.totalBytesAcked)
|
||||
b.a0Candidates.Clear()
|
||||
b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint())
|
||||
}
|
||||
b.totalBytesSentAtLastAckedPacket = b.totalBytesSent
|
||||
|
||||
// In this situation ack compression is not a concern, set send rate to
|
||||
// effectively infinite.
|
||||
b.lastAckedPacketSentTime = sentTime
|
||||
}
|
||||
|
||||
b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket(
|
||||
sentTime,
|
||||
bytes,
|
||||
bytesInFlight+bytes,
|
||||
b,
|
||||
))
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) OnCongestionEvent(
|
||||
ackTime time.Time,
|
||||
ackedPackets []congestion.AckedPacketInfo,
|
||||
lostPackets []congestion.LostPacketInfo,
|
||||
maxBandwidth Bandwidth,
|
||||
estBandwidthUpperBound Bandwidth,
|
||||
roundTripCount roundTripCount,
|
||||
) congestionEventSample {
|
||||
eventSample := newCongestionEventSample()
|
||||
|
||||
var lastLostPacketSendState sendTimeState
|
||||
|
||||
for _, p := range lostPackets {
|
||||
sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost)
|
||||
if sendState.isValid {
|
||||
lastLostPacketSendState = sendState
|
||||
}
|
||||
}
|
||||
|
||||
if len(ackedPackets) == 0 {
|
||||
// Only populate send state for a loss-only event.
|
||||
eventSample.lastPacketSendState = lastLostPacketSendState
|
||||
return *eventSample
|
||||
}
|
||||
|
||||
var lastAckedPacketSendState sendTimeState
|
||||
var maxSendRate Bandwidth
|
||||
|
||||
for _, p := range ackedPackets {
|
||||
sample := b.onPacketAcknowledged(ackTime, p.PacketNumber)
|
||||
if !sample.stateAtSend.isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
lastAckedPacketSendState = sample.stateAtSend
|
||||
|
||||
if sample.rtt != 0 {
|
||||
eventSample.sampleRtt = min(eventSample.sampleRtt, sample.rtt)
|
||||
}
|
||||
if sample.bandwidth > eventSample.sampleMaxBandwidth {
|
||||
eventSample.sampleMaxBandwidth = sample.bandwidth
|
||||
eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited
|
||||
}
|
||||
if sample.sendRate != infBandwidth {
|
||||
maxSendRate = max(maxSendRate, sample.sendRate)
|
||||
}
|
||||
inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked
|
||||
if inflightSample > eventSample.sampleMaxInflight {
|
||||
eventSample.sampleMaxInflight = inflightSample
|
||||
}
|
||||
}
|
||||
|
||||
if !lastLostPacketSendState.isValid {
|
||||
eventSample.lastPacketSendState = lastAckedPacketSendState
|
||||
} else if !lastAckedPacketSendState.isValid {
|
||||
eventSample.lastPacketSendState = lastLostPacketSendState
|
||||
} else {
|
||||
// If two packets are inflight and an alarm is armed to lose a packet and it
|
||||
// wakes up late, then the first of two in flight packets could have been
|
||||
// acknowledged before the wakeup, which re-evaluates loss detection, and
|
||||
// could declare the later of the two lost.
|
||||
if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber {
|
||||
eventSample.lastPacketSendState = lastLostPacketSendState
|
||||
} else {
|
||||
eventSample.lastPacketSendState = lastAckedPacketSendState
|
||||
}
|
||||
}
|
||||
|
||||
isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth
|
||||
maxBandwidth = max(maxBandwidth, eventSample.sampleMaxBandwidth)
|
||||
if b.limitMaxAckHeightTrackerBySendRate {
|
||||
maxBandwidth = max(maxBandwidth, maxSendRate)
|
||||
}
|
||||
|
||||
eventSample.extraAcked = b.onAckEventEnd(min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount)
|
||||
|
||||
return *eventSample
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) {
|
||||
b.totalBytesLost += bytesLost
|
||||
if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil {
|
||||
sentPacketToSendTimeState(sentPacketPointer, &s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) {
|
||||
b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) {
|
||||
b.totalBytesNeutered += sentPacket.size
|
||||
})
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) OnAppLimited() {
|
||||
b.isAppLimited = true
|
||||
b.endOfAppLimitedPhase = b.lastSentPacket
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) {
|
||||
// A packet can become obsolete when it is removed from QuicUnackedPacketMap's
|
||||
// view of inflight before it is acked or marked as lost. For example, when
|
||||
// QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet,
|
||||
// the packet is removed from QuicUnackedPacketMap's inflight, but is not
|
||||
// marked as acked or lost in the BandwidthSampler.
|
||||
b.connectionStateMap.RemoveUpTo(leastUnacked)
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount {
|
||||
return b.totalBytesSent
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount {
|
||||
return b.totalBytesLost
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount {
|
||||
return b.totalBytesAcked
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount {
|
||||
return b.totalBytesNeutered
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) IsAppLimited() bool {
|
||||
return b.isAppLimited
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber {
|
||||
return b.endOfAppLimitedPhase
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) max_ack_height() congestion.ByteCount {
|
||||
return b.maxAckHeightTracker.Get()
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool {
|
||||
if b.a0Candidates.Empty() {
|
||||
return false
|
||||
}
|
||||
|
||||
if b.a0Candidates.Len() == 1 {
|
||||
*a0 = *b.a0Candidates.Front()
|
||||
return true
|
||||
}
|
||||
|
||||
for i := 1; i < b.a0Candidates.Len(); i++ {
|
||||
if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked {
|
||||
*a0 = *b.a0Candidates.Offset(i - 1)
|
||||
if i > 1 {
|
||||
for j := 0; j < i-1; j++ {
|
||||
b.a0Candidates.PopFront()
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
*a0 = *b.a0Candidates.Back()
|
||||
for k := 0; k < b.a0Candidates.Len()-1; k++ {
|
||||
b.a0Candidates.PopFront()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) onPacketAcknowledged(ackTime time.Time, packetNumber congestion.PacketNumber) bandwidthSample {
|
||||
sample := newBandwidthSample()
|
||||
b.lastAckedPacket = packetNumber
|
||||
sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber)
|
||||
if sentPacketPointer == nil {
|
||||
return *sample
|
||||
}
|
||||
|
||||
// OnPacketAcknowledgedInner
|
||||
b.totalBytesAcked += sentPacketPointer.size
|
||||
b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent
|
||||
b.lastAckedPacketSentTime = sentPacketPointer.sentTime
|
||||
b.lastAckedPacketAckTime = ackTime
|
||||
if b.overestimateAvoidance {
|
||||
b.recentAckPoints.Update(ackTime, b.totalBytesAcked)
|
||||
}
|
||||
|
||||
if b.isAppLimited {
|
||||
// Exit app-limited phase in two cases:
|
||||
// (1) end_of_app_limited_phase_ is not initialized, i.e., so far all
|
||||
// packets are sent while there are buffered packets or pending data.
|
||||
// (2) The current acked packet is after the sent packet marked as the end
|
||||
// of the app limit phase.
|
||||
if b.endOfAppLimitedPhase == invalidPacketNumber ||
|
||||
packetNumber > b.endOfAppLimitedPhase {
|
||||
b.isAppLimited = false
|
||||
}
|
||||
}
|
||||
|
||||
// There might have been no packets acknowledged at the moment when the
|
||||
// current packet was sent. In that case, there is no bandwidth sample to
|
||||
// make.
|
||||
if sentPacketPointer.lastAckedPacketSentTime.IsZero() {
|
||||
return *sample
|
||||
}
|
||||
|
||||
// Infinite rate indicates that the sampler is supposed to discard the
|
||||
// current send rate sample and use only the ack rate.
|
||||
sendRate := infBandwidth
|
||||
if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) {
|
||||
sendRate = BandwidthFromDelta(
|
||||
sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket,
|
||||
sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime))
|
||||
}
|
||||
|
||||
var a0 ackPoint
|
||||
if b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0) {
|
||||
} else {
|
||||
a0.ackTime = sentPacketPointer.lastAckedPacketAckTime
|
||||
a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked
|
||||
}
|
||||
|
||||
// During the slope calculation, ensure that ack time of the current packet is
|
||||
// always larger than the time of the previous packet, otherwise division by
|
||||
// zero or integer underflow can occur.
|
||||
if ackTime.Sub(a0.ackTime) <= 0 {
|
||||
return *sample
|
||||
}
|
||||
|
||||
ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime))
|
||||
|
||||
sample.bandwidth = min(sendRate, ackRate)
|
||||
// Note: this sample does not account for delayed acknowledgement time. This
|
||||
// means that the RTT measurements here can be artificially high, especially
|
||||
// on low bandwidth connections.
|
||||
sample.rtt = ackTime.Sub(sentPacketPointer.sentTime)
|
||||
sample.sendRate = sendRate
|
||||
sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend)
|
||||
|
||||
return *sample
|
||||
}
|
||||
|
||||
func (b *bandwidthSampler) onAckEventEnd(
|
||||
bandwidthEstimate Bandwidth,
|
||||
isNewMaxBandwidth bool,
|
||||
roundTripCount roundTripCount,
|
||||
) congestion.ByteCount {
|
||||
newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent
|
||||
if newlyAckedBytes == 0 {
|
||||
return 0
|
||||
}
|
||||
b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked
|
||||
extraAcked := b.maxAckHeightTracker.Update(
|
||||
bandwidthEstimate,
|
||||
isNewMaxBandwidth,
|
||||
roundTripCount,
|
||||
b.lastSentPacket,
|
||||
b.lastAckedPacket,
|
||||
b.lastAckedPacketAckTime,
|
||||
newlyAckedBytes)
|
||||
// If |extra_acked| is zero, i.e. this ack event marks the start of a new ack
|
||||
// aggregation epoch, save LessRecentPoint, which is the last ack point of the
|
||||
// previous epoch, as a A0 candidate.
|
||||
if b.overestimateAvoidance && extraAcked == 0 {
|
||||
b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint())
|
||||
}
|
||||
return extraAcked
|
||||
}
|
||||
|
||||
func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) {
|
||||
*sendTimeState = sentPacket.sendTimeState
|
||||
sendTimeState.isValid = true
|
||||
}
|
||||
|
||||
// BytesFromBandwidthAndTimeDelta calculates the bytes
|
||||
// from a bandwidth(bits per second) and a time delta
|
||||
func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount {
|
||||
return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) /
|
||||
(congestion.ByteCount(time.Second) * 8)
|
||||
}
|
||||
|
||||
func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration {
|
||||
return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth)
|
||||
}
|
944
transport/hysteria2/core/internal/congestion/bbr/bbr_sender.go
Normal file
944
transport/hysteria2/core/internal/congestion/bbr/bbr_sender.go
Normal file
|
@ -0,0 +1,944 @@
|
|||
package bbr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/quic-go/congestion"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/common"
|
||||
)
|
||||
|
||||
// BbrSender implements BBR congestion control algorithm. BBR aims to estimate
|
||||
// the current available Bottleneck Bandwidth and RTT (hence the name), and
|
||||
// regulates the pacing rate and the size of the congestion window based on
|
||||
// those signals.
|
||||
//
|
||||
// BBR relies on pacing in order to function properly. Do not use BBR when
|
||||
// pacing is disabled.
|
||||
//
|
||||
|
||||
const (
|
||||
minBps = 65536 // 64 kbps
|
||||
|
||||
invalidPacketNumber = -1
|
||||
initialCongestionWindowPackets = 32
|
||||
|
||||
// Constants based on TCP defaults.
|
||||
// The minimum CWND to ensure delayed acks don't reduce bandwidth measurements.
|
||||
// Does not inflate the pacing rate.
|
||||
defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSizeIPv4)
|
||||
|
||||
// The gain used for the STARTUP, equal to 2/ln(2).
|
||||
defaultHighGain = 2.885
|
||||
// The newly derived gain for STARTUP, equal to 4 * ln(2)
|
||||
derivedHighGain = 2.773
|
||||
// The newly derived CWND gain for STARTUP, 2.
|
||||
derivedHighCWNDGain = 2.0
|
||||
)
|
||||
|
||||
// The cycle of gains used during the PROBE_BW stage.
|
||||
var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}
|
||||
|
||||
const (
|
||||
// The length of the gain cycle.
|
||||
gainCycleLength = len(pacingGain)
|
||||
// The size of the bandwidth filter window, in round-trips.
|
||||
bandwidthWindowSize = gainCycleLength + 2
|
||||
|
||||
// The time after which the current min_rtt value expires.
|
||||
minRttExpiry = 10 * time.Second
|
||||
// The minimum time the connection can spend in PROBE_RTT mode.
|
||||
probeRttTime = 200 * time.Millisecond
|
||||
// If the bandwidth does not increase by the factor of |kStartupGrowthTarget|
|
||||
// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection
|
||||
// will exit the STARTUP mode.
|
||||
startupGrowthTarget = 1.25
|
||||
roundTripsWithoutGrowthBeforeExitingStartup = int64(3)
|
||||
|
||||
// Flag.
|
||||
defaultStartupFullLossCount = 8
|
||||
quicBbr2DefaultLossThreshold = 0.02
|
||||
maxBbrBurstPackets = 3
|
||||
)
|
||||
|
||||
type bbrMode int
|
||||
|
||||
const (
|
||||
// Startup phase of the connection.
|
||||
bbrModeStartup = iota
|
||||
// After achieving the highest possible bandwidth during the startup, lower
|
||||
// the pacing rate in order to drain the queue.
|
||||
bbrModeDrain
|
||||
// Cruising mode.
|
||||
bbrModeProbeBw
|
||||
// Temporarily slow down sending in order to empty the buffer and measure
|
||||
// the real minimum RTT.
|
||||
bbrModeProbeRtt
|
||||
)
|
||||
|
||||
// Indicates how the congestion control limits the amount of bytes in flight.
|
||||
type bbrRecoveryState int
|
||||
|
||||
const (
|
||||
// Do not limit.
|
||||
bbrRecoveryStateNotInRecovery = iota
|
||||
// Allow an extra outstanding byte for each byte acknowledged.
|
||||
bbrRecoveryStateConservation
|
||||
// Allow two extra outstanding bytes for each byte acknowledged (slow
|
||||
// start).
|
||||
bbrRecoveryStateGrowth
|
||||
)
|
||||
|
||||
type bbrSender struct {
|
||||
rttStats congestion.RTTStatsProvider
|
||||
clock Clock
|
||||
pacer *common.Pacer
|
||||
|
||||
mode bbrMode
|
||||
|
||||
// Bandwidth sampler provides BBR with the bandwidth measurements at
|
||||
// individual points.
|
||||
sampler *bandwidthSampler
|
||||
|
||||
// The number of the round trips that have occurred during the connection.
|
||||
roundTripCount roundTripCount
|
||||
|
||||
// The packet number of the most recently sent packet.
|
||||
lastSentPacket congestion.PacketNumber
|
||||
// Acknowledgement of any packet after |current_round_trip_end_| will cause
|
||||
// the round trip counter to advance.
|
||||
currentRoundTripEnd congestion.PacketNumber
|
||||
|
||||
// Number of congestion events with some losses, in the current round.
|
||||
numLossEventsInRound uint64
|
||||
|
||||
// Number of total bytes lost in the current round.
|
||||
bytesLostInRound congestion.ByteCount
|
||||
|
||||
// The filter that tracks the maximum bandwidth over the multiple recent
|
||||
// round-trips.
|
||||
maxBandwidth *WindowedFilter[Bandwidth, roundTripCount]
|
||||
|
||||
// Minimum RTT estimate. Automatically expires within 10 seconds (and
|
||||
// triggers PROBE_RTT mode) if no new value is sampled during that period.
|
||||
minRtt time.Duration
|
||||
// The time at which the current value of |min_rtt_| was assigned.
|
||||
minRttTimestamp time.Time
|
||||
|
||||
// The maximum allowed number of bytes in flight.
|
||||
congestionWindow congestion.ByteCount
|
||||
|
||||
// The initial value of the |congestion_window_|.
|
||||
initialCongestionWindow congestion.ByteCount
|
||||
|
||||
// The largest value the |congestion_window_| can achieve.
|
||||
maxCongestionWindow congestion.ByteCount
|
||||
|
||||
// The smallest value the |congestion_window_| can achieve.
|
||||
minCongestionWindow congestion.ByteCount
|
||||
|
||||
// The pacing gain applied during the STARTUP phase.
|
||||
highGain float64
|
||||
|
||||
// The CWND gain applied during the STARTUP phase.
|
||||
highCwndGain float64
|
||||
|
||||
// The pacing gain applied during the DRAIN phase.
|
||||
drainGain float64
|
||||
|
||||
// The current pacing rate of the connection.
|
||||
pacingRate Bandwidth
|
||||
|
||||
// The gain currently applied to the pacing rate.
|
||||
pacingGain float64
|
||||
// The gain currently applied to the congestion window.
|
||||
congestionWindowGain float64
|
||||
|
||||
// The gain used for the congestion window during PROBE_BW. Latched from
|
||||
// quic_bbr_cwnd_gain flag.
|
||||
congestionWindowGainConstant float64
|
||||
// The number of RTTs to stay in STARTUP mode. Defaults to 3.
|
||||
numStartupRtts int64
|
||||
|
||||
// Number of round-trips in PROBE_BW mode, used for determining the current
|
||||
// pacing gain cycle.
|
||||
cycleCurrentOffset int
|
||||
// The time at which the last pacing gain cycle was started.
|
||||
lastCycleStart time.Time
|
||||
|
||||
// Indicates whether the connection has reached the full bandwidth mode.
|
||||
isAtFullBandwidth bool
|
||||
// Number of rounds during which there was no significant bandwidth increase.
|
||||
roundsWithoutBandwidthGain int64
|
||||
// The bandwidth compared to which the increase is measured.
|
||||
bandwidthAtLastRound Bandwidth
|
||||
|
||||
// Set to true upon exiting quiescence.
|
||||
exitingQuiescence bool
|
||||
|
||||
// Time at which PROBE_RTT has to be exited. Setting it to zero indicates
|
||||
// that the time is yet unknown as the number of packets in flight has not
|
||||
// reached the required value.
|
||||
exitProbeRttAt time.Time
|
||||
// Indicates whether a round-trip has passed since PROBE_RTT became active.
|
||||
probeRttRoundPassed bool
|
||||
|
||||
// Indicates whether the most recent bandwidth sample was marked as
|
||||
// app-limited.
|
||||
lastSampleIsAppLimited bool
|
||||
// Indicates whether any non app-limited samples have been recorded.
|
||||
hasNoAppLimitedSample bool
|
||||
|
||||
// Current state of recovery.
|
||||
recoveryState bbrRecoveryState
|
||||
// Receiving acknowledgement of a packet after |end_recovery_at_| will cause
|
||||
// BBR to exit the recovery mode. A value above zero indicates at least one
|
||||
// loss has been detected, so it must not be set back to zero.
|
||||
endRecoveryAt congestion.PacketNumber
|
||||
// A window used to limit the number of bytes in flight during loss recovery.
|
||||
recoveryWindow congestion.ByteCount
|
||||
// If true, consider all samples in recovery app-limited.
|
||||
isAppLimitedRecovery bool // not used
|
||||
|
||||
// When true, pace at 1.5x and disable packet conservation in STARTUP.
|
||||
slowerStartup bool // not used
|
||||
// When true, disables packet conservation in STARTUP.
|
||||
rateBasedStartup bool // not used
|
||||
|
||||
// When true, add the most recent ack aggregation measurement during STARTUP.
|
||||
enableAckAggregationDuringStartup bool
|
||||
// When true, expire the windowed ack aggregation values in STARTUP when
|
||||
// bandwidth increases more than 25%.
|
||||
expireAckAggregationInStartup bool
|
||||
|
||||
// If true, will not exit low gain mode until bytes_in_flight drops below BDP
|
||||
// or it's time for high gain mode.
|
||||
drainToTarget bool
|
||||
|
||||
// If true, slow down pacing rate in STARTUP when overshooting is detected.
|
||||
detectOvershooting bool
|
||||
// Bytes lost while detect_overshooting_ is true.
|
||||
bytesLostWhileDetectingOvershooting congestion.ByteCount
|
||||
// Slow down pacing rate if
|
||||
// bytes_lost_while_detecting_overshooting_ *
|
||||
// bytes_lost_multiplier_while_detecting_overshooting_ > IW.
|
||||
bytesLostMultiplierWhileDetectingOvershooting uint8
|
||||
// When overshooting is detected, do not drop pacing_rate_ below this value /
|
||||
// min_rtt.
|
||||
cwndToCalculateMinPacingRate congestion.ByteCount
|
||||
|
||||
// Max congestion window when adjusting network parameters.
|
||||
maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used
|
||||
|
||||
// Params.
|
||||
maxDatagramSize congestion.ByteCount
|
||||
// Recorded on packet sent. equivalent |unacked_packets_->bytes_in_flight()|
|
||||
bytesInFlight congestion.ByteCount
|
||||
}
|
||||
|
||||
var _ congestion.CongestionControl = &bbrSender{}
|
||||
|
||||
func NewBbrSender(
|
||||
clock Clock,
|
||||
initialMaxDatagramSize congestion.ByteCount,
|
||||
) *bbrSender {
|
||||
return newBbrSender(
|
||||
clock,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindowPackets*initialMaxDatagramSize,
|
||||
congestion.MaxCongestionWindowPackets*initialMaxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
func newBbrSender(
|
||||
clock Clock,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow,
|
||||
initialMaxCongestionWindow congestion.ByteCount,
|
||||
) *bbrSender {
|
||||
b := &bbrSender{
|
||||
clock: clock,
|
||||
mode: bbrModeStartup,
|
||||
sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)),
|
||||
lastSentPacket: invalidPacketNumber,
|
||||
currentRoundTripEnd: invalidPacketNumber,
|
||||
maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]),
|
||||
congestionWindow: initialCongestionWindow,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
maxCongestionWindow: initialMaxCongestionWindow,
|
||||
minCongestionWindow: defaultMinimumCongestionWindow,
|
||||
highGain: defaultHighGain,
|
||||
highCwndGain: defaultHighGain,
|
||||
drainGain: 1.0 / defaultHighGain,
|
||||
pacingGain: 1.0,
|
||||
congestionWindowGain: 1.0,
|
||||
congestionWindowGainConstant: 2.0,
|
||||
numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup,
|
||||
recoveryState: bbrRecoveryStateNotInRecovery,
|
||||
endRecoveryAt: invalidPacketNumber,
|
||||
recoveryWindow: initialMaxCongestionWindow,
|
||||
bytesLostMultiplierWhileDetectingOvershooting: 2,
|
||||
cwndToCalculateMinPacingRate: initialCongestionWindow,
|
||||
maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow,
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
}
|
||||
b.pacer = common.NewPacer(b.bandwidthForPacer)
|
||||
|
||||
/*
|
||||
if b.tracer != nil {
|
||||
b.lastState = logging.CongestionStateStartup
|
||||
b.tracer.UpdatedCongestionState(logging.CongestionStateStartup)
|
||||
}
|
||||
*/
|
||||
|
||||
b.enterStartupMode(b.clock.Now())
|
||||
b.setHighCwndGain(derivedHighCWNDGain)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
|
||||
b.rttStats = provider
|
||||
}
|
||||
|
||||
// TimeUntilSend implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
||||
return b.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
// HasPacingBudget implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) HasPacingBudget(now time.Time) bool {
|
||||
return b.pacer.Budget(now) >= b.maxDatagramSize
|
||||
}
|
||||
|
||||
// OnPacketSent implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
bytesInFlight congestion.ByteCount,
|
||||
packetNumber congestion.PacketNumber,
|
||||
bytes congestion.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
b.pacer.SentPacket(sentTime, bytes)
|
||||
|
||||
b.lastSentPacket = packetNumber
|
||||
b.bytesInFlight = bytesInFlight
|
||||
|
||||
if bytesInFlight == 0 {
|
||||
b.exitingQuiescence = true
|
||||
}
|
||||
|
||||
b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable)
|
||||
}
|
||||
|
||||
// CanSend implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||
return bytesInFlight < b.GetCongestionWindow()
|
||||
}
|
||||
|
||||
// MaybeExitSlowStart implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) MaybeExitSlowStart() {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
// OnPacketAcked implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime time.Time) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
// OnPacketLost implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
// OnRetransmissionTimeout implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
// SetMaxDatagramSize implements the SendAlgorithm interface.
|
||||
func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
if s < b.maxDatagramSize {
|
||||
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s))
|
||||
}
|
||||
cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow
|
||||
b.maxDatagramSize = s
|
||||
if cwndIsMinCwnd {
|
||||
b.congestionWindow = b.minCongestionWindow
|
||||
}
|
||||
b.pacer.SetMaxDatagramSize(s)
|
||||
}
|
||||
|
||||
// InSlowStart implements the SendAlgorithmWithDebugInfos interface.
|
||||
func (b *bbrSender) InSlowStart() bool {
|
||||
return b.mode == bbrModeStartup
|
||||
}
|
||||
|
||||
// InRecovery implements the SendAlgorithmWithDebugInfos interface.
|
||||
func (b *bbrSender) InRecovery() bool {
|
||||
return b.recoveryState != bbrRecoveryStateNotInRecovery
|
||||
}
|
||||
|
||||
// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface.
|
||||
func (b *bbrSender) GetCongestionWindow() congestion.ByteCount {
|
||||
if b.mode == bbrModeProbeRtt {
|
||||
return b.probeRttCongestionWindow()
|
||||
}
|
||||
|
||||
if b.InRecovery() {
|
||||
return min(b.congestionWindow, b.recoveryWindow)
|
||||
}
|
||||
|
||||
return b.congestionWindow
|
||||
}
|
||||
|
||||
func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
|
||||
totalBytesAckedBefore := b.sampler.TotalBytesAcked()
|
||||
totalBytesLostBefore := b.sampler.TotalBytesLost()
|
||||
|
||||
var isRoundStart, minRttExpired bool
|
||||
var excessAcked, bytesLost congestion.ByteCount
|
||||
|
||||
// The send state of the largest packet in acked_packets, unless it is
|
||||
// empty. If acked_packets is empty, it's the send state of the largest
|
||||
// packet in lost_packets.
|
||||
var lastPacketSendState sendTimeState
|
||||
|
||||
b.maybeApplimited(priorInFlight)
|
||||
|
||||
// Update bytesInFlight
|
||||
b.bytesInFlight = priorInFlight
|
||||
for _, p := range ackedPackets {
|
||||
b.bytesInFlight -= p.BytesAcked
|
||||
}
|
||||
for _, p := range lostPackets {
|
||||
b.bytesInFlight -= p.BytesLost
|
||||
}
|
||||
|
||||
if len(ackedPackets) != 0 {
|
||||
lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber
|
||||
isRoundStart = b.updateRoundTripCounter(lastAckedPacket)
|
||||
b.updateRecoveryState(lastAckedPacket, len(lostPackets) != 0, isRoundStart)
|
||||
}
|
||||
|
||||
sample := b.sampler.OnCongestionEvent(eventTime,
|
||||
ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount)
|
||||
if sample.lastPacketSendState.isValid {
|
||||
b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited
|
||||
b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited
|
||||
}
|
||||
// Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all
|
||||
// packets in |acked_packets| did not generate valid samples. (e.g. ack of
|
||||
// ack-only packets). In both cases, sampler_.total_bytes_acked() will not
|
||||
// change.
|
||||
if totalBytesAckedBefore != b.sampler.TotalBytesAcked() {
|
||||
if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() {
|
||||
b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount)
|
||||
}
|
||||
}
|
||||
|
||||
if sample.sampleRtt != infRTT {
|
||||
minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt)
|
||||
}
|
||||
bytesLost = b.sampler.TotalBytesLost() - totalBytesLostBefore
|
||||
|
||||
excessAcked = sample.extraAcked
|
||||
lastPacketSendState = sample.lastPacketSendState
|
||||
|
||||
if len(lostPackets) != 0 {
|
||||
b.numLossEventsInRound++
|
||||
b.bytesLostInRound += bytesLost
|
||||
}
|
||||
|
||||
// Handle logic specific to PROBE_BW mode.
|
||||
if b.mode == bbrModeProbeBw {
|
||||
b.updateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) != 0)
|
||||
}
|
||||
|
||||
// Handle logic specific to STARTUP and DRAIN modes.
|
||||
if isRoundStart && !b.isAtFullBandwidth {
|
||||
b.checkIfFullBandwidthReached(&lastPacketSendState)
|
||||
}
|
||||
|
||||
b.maybeExitStartupOrDrain(eventTime)
|
||||
|
||||
// Handle logic specific to PROBE_RTT.
|
||||
b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
|
||||
|
||||
// Calculate number of packets acked and lost.
|
||||
bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore
|
||||
|
||||
// After the model is updated, recalculate the pacing rate and congestion
|
||||
// window.
|
||||
b.calculatePacingRate(bytesLost)
|
||||
b.calculateCongestionWindow(bytesAcked, excessAcked)
|
||||
b.calculateRecoveryWindow(bytesAcked, bytesLost)
|
||||
|
||||
// Cleanup internal state.
|
||||
// This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler.
|
||||
// The "least unacked" should actually be FirstOutstanding, but since we are not passing
|
||||
// that through OnCongestionEventEx, we will only do an estimate using acked/lost packets
|
||||
// for now. Because of fast retransmission, they should differ by no more than 2 packets.
|
||||
// (this is controlled by packetThreshold in quic-go's sentPacketHandler)
|
||||
var leastUnacked congestion.PacketNumber
|
||||
if len(ackedPackets) != 0 {
|
||||
leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2
|
||||
} else {
|
||||
leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1
|
||||
}
|
||||
b.sampler.RemoveObsoletePackets(leastUnacked)
|
||||
|
||||
if isRoundStart {
|
||||
b.numLossEventsInRound = 0
|
||||
b.bytesLostInRound = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bbrSender) PacingRate() Bandwidth {
|
||||
if b.pacingRate == 0 {
|
||||
return Bandwidth(b.highGain * float64(
|
||||
BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt())))
|
||||
}
|
||||
|
||||
return b.pacingRate
|
||||
}
|
||||
|
||||
func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool {
|
||||
return b.hasNonAppLimitedSample()
|
||||
}
|
||||
|
||||
func (b *bbrSender) hasNonAppLimitedSample() bool {
|
||||
return b.hasNoAppLimitedSample
|
||||
}
|
||||
|
||||
// Sets the pacing gain used in STARTUP. Must be greater than 1.
|
||||
func (b *bbrSender) setHighGain(highGain float64) {
|
||||
b.highGain = highGain
|
||||
if b.mode == bbrModeStartup {
|
||||
b.pacingGain = highGain
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the CWND gain used in STARTUP. Must be greater than 1.
|
||||
func (b *bbrSender) setHighCwndGain(highCwndGain float64) {
|
||||
b.highCwndGain = highCwndGain
|
||||
if b.mode == bbrModeStartup {
|
||||
b.congestionWindowGain = highCwndGain
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the gain used in DRAIN. Must be less than 1.
|
||||
func (b *bbrSender) setDrainGain(drainGain float64) {
|
||||
b.drainGain = drainGain
|
||||
}
|
||||
|
||||
// What's the current estimated bandwidth in bytes per second.
|
||||
func (b *bbrSender) bandwidthEstimate() Bandwidth {
|
||||
return b.maxBandwidth.GetBest()
|
||||
}
|
||||
|
||||
func (b *bbrSender) bandwidthForPacer() congestion.ByteCount {
|
||||
bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond))
|
||||
if bps < minBps {
|
||||
// We need to make sure that the bandwidth value for pacer is never zero,
|
||||
// otherwise it will go into an edge case where HasPacingBudget = false
|
||||
// but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck.
|
||||
return minBps
|
||||
}
|
||||
return bps
|
||||
}
|
||||
|
||||
// Returns the current estimate of the RTT of the connection. Outside of the
|
||||
// edge cases, this is minimum RTT.
|
||||
func (b *bbrSender) getMinRtt() time.Duration {
|
||||
if b.minRtt != 0 {
|
||||
return b.minRtt
|
||||
}
|
||||
// min_rtt could be available if the handshake packet gets neutered then
|
||||
// gets acknowledged. This could only happen for QUIC crypto where we do not
|
||||
// drop keys.
|
||||
minRtt := b.rttStats.MinRTT()
|
||||
if minRtt == 0 {
|
||||
return 100 * time.Millisecond
|
||||
} else {
|
||||
return minRtt
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the target congestion window using the specified gain.
|
||||
func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount {
|
||||
bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate())
|
||||
congestionWindow := congestion.ByteCount(gain * float64(bdp))
|
||||
|
||||
// BDP estimate will be zero if no bandwidth samples are available yet.
|
||||
if congestionWindow == 0 {
|
||||
congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow))
|
||||
}
|
||||
|
||||
return max(congestionWindow, b.minCongestionWindow)
|
||||
}
|
||||
|
||||
// The target congestion window during PROBE_RTT.
|
||||
func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount {
|
||||
return b.minCongestionWindow
|
||||
}
|
||||
|
||||
func (b *bbrSender) maybeUpdateMinRtt(now time.Time, sampleMinRtt time.Duration) bool {
|
||||
// Do not expire min_rtt if none was ever available.
|
||||
minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry))
|
||||
if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 {
|
||||
b.minRtt = sampleMinRtt
|
||||
b.minRttTimestamp = now
|
||||
}
|
||||
|
||||
return minRttExpired
|
||||
}
|
||||
|
||||
// Enters the STARTUP mode.
|
||||
func (b *bbrSender) enterStartupMode(now time.Time) {
|
||||
b.mode = bbrModeStartup
|
||||
// b.maybeTraceStateChange(logging.CongestionStateStartup)
|
||||
b.pacingGain = b.highGain
|
||||
b.congestionWindowGain = b.highCwndGain
|
||||
}
|
||||
|
||||
// Enters the PROBE_BW mode.
|
||||
func (b *bbrSender) enterProbeBandwidthMode(now time.Time) {
|
||||
b.mode = bbrModeProbeBw
|
||||
// b.maybeTraceStateChange(logging.CongestionStateProbeBw)
|
||||
b.congestionWindowGain = b.congestionWindowGainConstant
|
||||
|
||||
// Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
|
||||
// excluded because in that case increased gain and decreased gain would not
|
||||
// follow each other.
|
||||
b.cycleCurrentOffset = int(rand.Int31n(congestion.PacketsPerConnectionID)) % (gainCycleLength - 1)
|
||||
if b.cycleCurrentOffset >= 1 {
|
||||
b.cycleCurrentOffset += 1
|
||||
}
|
||||
|
||||
b.lastCycleStart = now
|
||||
b.pacingGain = pacingGain[b.cycleCurrentOffset]
|
||||
}
|
||||
|
||||
// Updates the round-trip counter if a round-trip has passed. Returns true if
|
||||
// the counter has been advanced.
|
||||
func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool {
|
||||
if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd {
|
||||
b.roundTripCount++
|
||||
b.currentRoundTripEnd = b.lastSentPacket
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Updates the current gain used in PROBE_BW mode.
|
||||
func (b *bbrSender) updateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) {
|
||||
// In most cases, the cycle is advanced after an RTT passes.
|
||||
shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt()))
|
||||
// If the pacing gain is above 1.0, the connection is trying to probe the
|
||||
// bandwidth by increasing the number of bytes in flight to at least
|
||||
// pacing_gain * BDP. Make sure that it actually reaches the target, as long
|
||||
// as there are no losses suggesting that the buffers are not able to hold
|
||||
// that much.
|
||||
if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) {
|
||||
shouldAdvanceGainCycling = false
|
||||
}
|
||||
|
||||
// If pacing gain is below 1.0, the connection is trying to drain the extra
|
||||
// queue which could have been incurred by probing prior to it. If the number
|
||||
// of bytes in flight falls down to the estimated BDP value earlier, conclude
|
||||
// that the queue has been successfully drained and exit this cycle early.
|
||||
if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) {
|
||||
shouldAdvanceGainCycling = true
|
||||
}
|
||||
|
||||
if shouldAdvanceGainCycling {
|
||||
b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength
|
||||
b.lastCycleStart = now
|
||||
// Stay in low gain mode until the target BDP is hit.
|
||||
// Low gain mode will be exited immediately when the target BDP is achieved.
|
||||
if b.drainToTarget && b.pacingGain < 1 &&
|
||||
pacingGain[b.cycleCurrentOffset] == 1 &&
|
||||
b.bytesInFlight > b.getTargetCongestionWindow(1) {
|
||||
return
|
||||
}
|
||||
b.pacingGain = pacingGain[b.cycleCurrentOffset]
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks for how many round-trips the bandwidth has not increased
|
||||
// significantly.
|
||||
func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) {
|
||||
if b.lastSampleIsAppLimited {
|
||||
return
|
||||
}
|
||||
|
||||
target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget)
|
||||
if b.bandwidthEstimate() >= target {
|
||||
b.bandwidthAtLastRound = b.bandwidthEstimate()
|
||||
b.roundsWithoutBandwidthGain = 0
|
||||
if b.expireAckAggregationInStartup {
|
||||
// Expire old excess delivery measurements now that bandwidth increased.
|
||||
b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
b.roundsWithoutBandwidthGain++
|
||||
if b.roundsWithoutBandwidthGain >= b.numStartupRtts ||
|
||||
b.shouldExitStartupDueToLoss(lastPacketSendState) {
|
||||
b.isAtFullBandwidth = true
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bbrSender) maybeApplimited(bytesInFlight congestion.ByteCount) {
|
||||
congestionWindow := b.GetCongestionWindow()
|
||||
if bytesInFlight >= congestionWindow {
|
||||
return
|
||||
}
|
||||
availableBytes := congestionWindow - bytesInFlight
|
||||
drainLimited := b.mode == bbrModeDrain && bytesInFlight > congestionWindow/2
|
||||
if !drainLimited || availableBytes > maxBbrBurstPackets*b.maxDatagramSize {
|
||||
b.sampler.OnAppLimited()
|
||||
}
|
||||
}
|
||||
|
||||
// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if
|
||||
// appropriate.
|
||||
func (b *bbrSender) maybeExitStartupOrDrain(now time.Time) {
|
||||
if b.mode == bbrModeStartup && b.isAtFullBandwidth {
|
||||
b.mode = bbrModeDrain
|
||||
// b.maybeTraceStateChange(logging.CongestionStateDrain)
|
||||
b.pacingGain = b.drainGain
|
||||
b.congestionWindowGain = b.highCwndGain
|
||||
}
|
||||
if b.mode == bbrModeDrain && b.bytesInFlight <= b.getTargetCongestionWindow(1) {
|
||||
b.enterProbeBandwidthMode(now)
|
||||
}
|
||||
}
|
||||
|
||||
// Decides whether to enter or exit PROBE_RTT.
|
||||
func (b *bbrSender) maybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) {
|
||||
if minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt {
|
||||
b.mode = bbrModeProbeRtt
|
||||
// b.maybeTraceStateChange(logging.CongestionStateProbRtt)
|
||||
b.pacingGain = 1.0
|
||||
// Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight|
|
||||
// is at the target small value.
|
||||
b.exitProbeRttAt = time.Time{}
|
||||
}
|
||||
|
||||
if b.mode == bbrModeProbeRtt {
|
||||
b.sampler.OnAppLimited()
|
||||
// b.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
||||
|
||||
if b.exitProbeRttAt.IsZero() {
|
||||
// If the window has reached the appropriate size, schedule exiting
|
||||
// PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but
|
||||
// we allow an extra packet since QUIC checks CWND before sending a
|
||||
// packet.
|
||||
if b.bytesInFlight < b.probeRttCongestionWindow()+congestion.MaxPacketBufferSize {
|
||||
b.exitProbeRttAt = now.Add(probeRttTime)
|
||||
b.probeRttRoundPassed = false
|
||||
}
|
||||
} else {
|
||||
if isRoundStart {
|
||||
b.probeRttRoundPassed = true
|
||||
}
|
||||
if now.Sub(b.exitProbeRttAt) >= 0 && b.probeRttRoundPassed {
|
||||
b.minRttTimestamp = now
|
||||
if !b.isAtFullBandwidth {
|
||||
b.enterStartupMode(now)
|
||||
} else {
|
||||
b.enterProbeBandwidthMode(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.exitingQuiescence = false
|
||||
}
|
||||
|
||||
// Determines whether BBR needs to enter, exit or advance state of the
|
||||
// recovery.
|
||||
func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) {
|
||||
// Disable recovery in startup, if loss-based exit is enabled.
|
||||
if !b.isAtFullBandwidth {
|
||||
return
|
||||
}
|
||||
|
||||
// Exit recovery when there are no losses for a round.
|
||||
if hasLosses {
|
||||
b.endRecoveryAt = b.lastSentPacket
|
||||
}
|
||||
|
||||
switch b.recoveryState {
|
||||
case bbrRecoveryStateNotInRecovery:
|
||||
if hasLosses {
|
||||
b.recoveryState = bbrRecoveryStateConservation
|
||||
// This will cause the |recovery_window_| to be set to the correct
|
||||
// value in CalculateRecoveryWindow().
|
||||
b.recoveryWindow = 0
|
||||
// Since the conservation phase is meant to be lasting for a whole
|
||||
// round, extend the current round as if it were started right now.
|
||||
b.currentRoundTripEnd = b.lastSentPacket
|
||||
}
|
||||
case bbrRecoveryStateConservation:
|
||||
if isRoundStart {
|
||||
b.recoveryState = bbrRecoveryStateGrowth
|
||||
}
|
||||
fallthrough
|
||||
case bbrRecoveryStateGrowth:
|
||||
// Exit recovery if appropriate.
|
||||
if !hasLosses && lastAckedPacket > b.endRecoveryAt {
|
||||
b.recoveryState = bbrRecoveryStateNotInRecovery
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determines the appropriate pacing rate for the connection.
|
||||
func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) {
|
||||
if b.bandwidthEstimate() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate()))
|
||||
if b.isAtFullBandwidth {
|
||||
b.pacingRate = targetRate
|
||||
return
|
||||
}
|
||||
|
||||
// Pace at the rate of initial_window / RTT as soon as RTT measurements are
|
||||
// available.
|
||||
if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 {
|
||||
b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT())
|
||||
return
|
||||
}
|
||||
|
||||
if b.detectOvershooting {
|
||||
b.bytesLostWhileDetectingOvershooting += bytesLost
|
||||
// Check for overshooting with network parameters adjusted when pacing rate
|
||||
// > target_rate and loss has been detected.
|
||||
if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 {
|
||||
if b.hasNoAppLimitedSample ||
|
||||
b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow {
|
||||
// We are fairly sure overshoot happens if 1) there is at least one
|
||||
// non app-limited bw sample or 2) half of IW gets lost. Slow pacing
|
||||
// rate.
|
||||
b.pacingRate = max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT()))
|
||||
b.bytesLostWhileDetectingOvershooting = 0
|
||||
b.detectOvershooting = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do not decrease the pacing rate during startup.
|
||||
b.pacingRate = max(b.pacingRate, targetRate)
|
||||
}
|
||||
|
||||
// Determines the appropriate congestion window for the connection.
|
||||
func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) {
|
||||
if b.mode == bbrModeProbeRtt {
|
||||
return
|
||||
}
|
||||
|
||||
targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain)
|
||||
if b.isAtFullBandwidth {
|
||||
// Add the max recently measured ack aggregation to CWND.
|
||||
targetWindow += b.sampler.MaxAckHeight()
|
||||
} else if b.enableAckAggregationDuringStartup {
|
||||
// Add the most recent excess acked. Because CWND never decreases in
|
||||
// STARTUP, this will automatically create a very localized max filter.
|
||||
targetWindow += excessAcked
|
||||
}
|
||||
|
||||
// Instead of immediately setting the target CWND as the new one, BBR grows
|
||||
// the CWND towards |target_window| by only increasing it |bytes_acked| at a
|
||||
// time.
|
||||
if b.isAtFullBandwidth {
|
||||
b.congestionWindow = min(targetWindow, b.congestionWindow+bytesAcked)
|
||||
} else if b.congestionWindow < targetWindow ||
|
||||
b.sampler.TotalBytesAcked() < b.initialCongestionWindow {
|
||||
// If the connection is not yet out of startup phase, do not decrease the
|
||||
// window.
|
||||
b.congestionWindow += bytesAcked
|
||||
}
|
||||
|
||||
// Enforce the limits on the congestion window.
|
||||
b.congestionWindow = max(b.congestionWindow, b.minCongestionWindow)
|
||||
b.congestionWindow = min(b.congestionWindow, b.maxCongestionWindow)
|
||||
}
|
||||
|
||||
// Determines the appropriate window that constrains the in-flight during recovery.
|
||||
func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) {
|
||||
if b.recoveryState == bbrRecoveryStateNotInRecovery {
|
||||
return
|
||||
}
|
||||
|
||||
// Set up the initial recovery window.
|
||||
if b.recoveryWindow == 0 {
|
||||
b.recoveryWindow = b.bytesInFlight + bytesAcked
|
||||
b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove losses from the recovery window, while accounting for a potential
|
||||
// integer underflow.
|
||||
if b.recoveryWindow >= bytesLost {
|
||||
b.recoveryWindow = b.recoveryWindow - bytesLost
|
||||
} else {
|
||||
b.recoveryWindow = b.maxDatagramSize
|
||||
}
|
||||
|
||||
// In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH,
|
||||
// release additional |bytes_acked| to achieve a slow-start-like behavior.
|
||||
if b.recoveryState == bbrRecoveryStateGrowth {
|
||||
b.recoveryWindow += bytesAcked
|
||||
}
|
||||
|
||||
// Always allow sending at least |bytes_acked| in response.
|
||||
b.recoveryWindow = max(b.recoveryWindow, b.bytesInFlight+bytesAcked)
|
||||
b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow)
|
||||
}
|
||||
|
||||
// Return whether we should exit STARTUP due to excessive loss.
|
||||
func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool {
|
||||
if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid {
|
||||
return false
|
||||
}
|
||||
|
||||
inflightAtSend := lastPacketSendState.bytesInFlight
|
||||
|
||||
if inflightAtSend > 0 && b.bytesLostInRound > 0 {
|
||||
if b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount {
|
||||
return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second)
|
||||
}
|
||||
|
||||
func GetInitialPacketSize(addr net.Addr) congestion.ByteCount {
|
||||
// If this is not a UDP address, we don't know anything about the MTU.
|
||||
// Use the minimum size of an Initial packet as the max packet size.
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||
if udpAddr.IP.To4() != nil {
|
||||
return congestion.InitialPacketSizeIPv4
|
||||
} else {
|
||||
return congestion.InitialPacketSizeIPv6
|
||||
}
|
||||
} else {
|
||||
return congestion.MinInitialPacketSize
|
||||
}
|
||||
}
|
18
transport/hysteria2/core/internal/congestion/bbr/clock.go
Normal file
18
transport/hysteria2/core/internal/congestion/bbr/clock.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package bbr
|
||||
|
||||
import "time"
|
||||
|
||||
// A Clock returns the current time
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||
type DefaultClock struct{}
|
||||
|
||||
var _ Clock = DefaultClock{}
|
||||
|
||||
// Now gets the current time
|
||||
func (DefaultClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
|
@ -0,0 +1,199 @@
|
|||
package bbr
|
||||
|
||||
import (
|
||||
"github.com/metacubex/quic-go/congestion"
|
||||
)
|
||||
|
||||
// packetNumberIndexedQueue is a queue of mostly continuous numbered entries
|
||||
// which supports the following operations:
|
||||
// - adding elements to the end of the queue, or at some point past the end
|
||||
// - removing elements in any order
|
||||
// - retrieving elements
|
||||
// If all elements are inserted in order, all of the operations above are
|
||||
// amortized O(1) time.
|
||||
//
|
||||
// Internally, the data structure is a deque where each element is marked as
|
||||
// present or not. The deque starts at the lowest present index. Whenever an
|
||||
// element is removed, it's marked as not present, and the front of the deque is
|
||||
// cleared of elements that are not present.
|
||||
//
|
||||
// The tail of the queue is not cleared due to the assumption of entries being
|
||||
// inserted in order, though removing all elements of the queue will return it
|
||||
// to its initial state.
|
||||
//
|
||||
// Note that this data structure is inherently hazardous, since an addition of
|
||||
// just two entries will cause it to consume all of the memory available.
|
||||
// Because of that, it is not a general-purpose container and should not be used
|
||||
// as one.
|
||||
|
||||
type entryWrapper[T any] struct {
|
||||
present bool
|
||||
entry T
|
||||
}
|
||||
|
||||
type packetNumberIndexedQueue[T any] struct {
|
||||
entries RingBuffer[entryWrapper[T]]
|
||||
numberOfPresentEntries int
|
||||
firstPacket congestion.PacketNumber
|
||||
}
|
||||
|
||||
func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] {
|
||||
q := &packetNumberIndexedQueue[T]{
|
||||
firstPacket: invalidPacketNumber,
|
||||
}
|
||||
|
||||
q.entries.Init(size)
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// Emplace inserts data associated |packet_number| into (or past) the end of the
|
||||
// queue, filling up the missing intermediate entries as necessary. Returns
|
||||
// true if the element has been inserted successfully, false if it was already
|
||||
// in the queue or inserted out of order.
|
||||
func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool {
|
||||
if packetNumber == invalidPacketNumber || entry == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if p.IsEmpty() {
|
||||
p.entries.PushBack(entryWrapper[T]{
|
||||
present: true,
|
||||
entry: *entry,
|
||||
})
|
||||
p.numberOfPresentEntries = 1
|
||||
p.firstPacket = packetNumber
|
||||
return true
|
||||
}
|
||||
|
||||
// Do not allow insertion out-of-order.
|
||||
if packetNumber <= p.LastPacket() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle potentially missing elements.
|
||||
offset := int(packetNumber - p.FirstPacket())
|
||||
if gap := offset - p.entries.Len(); gap > 0 {
|
||||
for i := 0; i < gap; i++ {
|
||||
p.entries.PushBack(entryWrapper[T]{})
|
||||
}
|
||||
}
|
||||
|
||||
p.entries.PushBack(entryWrapper[T]{
|
||||
present: true,
|
||||
entry: *entry,
|
||||
})
|
||||
p.numberOfPresentEntries++
|
||||
return true
|
||||
}
|
||||
|
||||
// GetEntry Retrieve the entry associated with the packet number. Returns the pointer
|
||||
// to the entry in case of success, or nullptr if the entry does not exist.
|
||||
func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T {
|
||||
ew := p.getEntryWraper(packetNumber)
|
||||
if ew == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ew.entry
|
||||
}
|
||||
|
||||
// Remove, Same as above, but if an entry is present in the queue, also call f(entry)
|
||||
// before removing it.
|
||||
func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool {
|
||||
ew := p.getEntryWraper(packetNumber)
|
||||
if ew == nil {
|
||||
return false
|
||||
}
|
||||
if f != nil {
|
||||
f(ew.entry)
|
||||
}
|
||||
ew.present = false
|
||||
p.numberOfPresentEntries--
|
||||
|
||||
if packetNumber == p.FirstPacket() {
|
||||
p.clearup()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveUpTo, but not including |packet_number|.
|
||||
// Unused slots in the front are also removed, which means when the function
|
||||
// returns, |first_packet()| can be larger than |packet_number|.
|
||||
func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) {
|
||||
for !p.entries.Empty() &&
|
||||
p.firstPacket != invalidPacketNumber &&
|
||||
p.firstPacket < packetNumber {
|
||||
if p.entries.Front().present {
|
||||
p.numberOfPresentEntries--
|
||||
}
|
||||
p.entries.PopFront()
|
||||
p.firstPacket++
|
||||
}
|
||||
p.clearup()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsEmpty return if queue is empty.
|
||||
func (p *packetNumberIndexedQueue[T]) IsEmpty() bool {
|
||||
return p.numberOfPresentEntries == 0
|
||||
}
|
||||
|
||||
// NumberOfPresentEntries returns the number of entries in the queue.
|
||||
func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int {
|
||||
return p.numberOfPresentEntries
|
||||
}
|
||||
|
||||
// EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is
|
||||
// proportional to the memory usage of the queue.
|
||||
func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int {
|
||||
return p.entries.Len()
|
||||
}
|
||||
|
||||
// LastPacket returns packet number of the first entry in the queue.
|
||||
func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) {
|
||||
return p.firstPacket
|
||||
}
|
||||
|
||||
// LastPacket returns packet number of the last entry ever inserted in the queue. Note that the
|
||||
// entry in question may have already been removed. Zero if the queue is
|
||||
// empty.
|
||||
func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) {
|
||||
if p.IsEmpty() {
|
||||
return invalidPacketNumber
|
||||
}
|
||||
|
||||
return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1)
|
||||
}
|
||||
|
||||
func (p *packetNumberIndexedQueue[T]) clearup() {
|
||||
for !p.entries.Empty() && !p.entries.Front().present {
|
||||
p.entries.PopFront()
|
||||
p.firstPacket++
|
||||
}
|
||||
if p.entries.Empty() {
|
||||
p.firstPacket = invalidPacketNumber
|
||||
}
|
||||
}
|
||||
|
||||
func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] {
|
||||
if packetNumber == invalidPacketNumber ||
|
||||
p.IsEmpty() ||
|
||||
packetNumber < p.firstPacket {
|
||||
return nil
|
||||
}
|
||||
|
||||
offset := int(packetNumber - p.firstPacket)
|
||||
if offset >= p.entries.Len() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ew := p.entries.Offset(offset)
|
||||
if ew == nil || !ew.present {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ew
|
||||
}
|
118
transport/hysteria2/core/internal/congestion/bbr/ringbuffer.go
Normal file
118
transport/hysteria2/core/internal/congestion/bbr/ringbuffer.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package bbr
|
||||
|
||||
// A RingBuffer is a ring buffer.
|
||||
// It acts as a heap that doesn't cause any allocations.
|
||||
type RingBuffer[T any] struct {
|
||||
ring []T
|
||||
headPos, tailPos int
|
||||
full bool
|
||||
}
|
||||
|
||||
// Init preallocs a buffer with a certain size.
|
||||
func (r *RingBuffer[T]) Init(size int) {
|
||||
r.ring = make([]T, size)
|
||||
}
|
||||
|
||||
// Len returns the number of elements in the ring buffer.
|
||||
func (r *RingBuffer[T]) Len() int {
|
||||
if r.full {
|
||||
return len(r.ring)
|
||||
}
|
||||
if r.tailPos >= r.headPos {
|
||||
return r.tailPos - r.headPos
|
||||
}
|
||||
return r.tailPos - r.headPos + len(r.ring)
|
||||
}
|
||||
|
||||
// Empty says if the ring buffer is empty.
|
||||
func (r *RingBuffer[T]) Empty() bool {
|
||||
return !r.full && r.headPos == r.tailPos
|
||||
}
|
||||
|
||||
// PushBack adds a new element.
|
||||
// If the ring buffer is full, its capacity is increased first.
|
||||
func (r *RingBuffer[T]) PushBack(t T) {
|
||||
if r.full || len(r.ring) == 0 {
|
||||
r.grow()
|
||||
}
|
||||
r.ring[r.tailPos] = t
|
||||
r.tailPos++
|
||||
if r.tailPos == len(r.ring) {
|
||||
r.tailPos = 0
|
||||
}
|
||||
if r.tailPos == r.headPos {
|
||||
r.full = true
|
||||
}
|
||||
}
|
||||
|
||||
// PopFront returns the next element.
|
||||
// It must not be called when the buffer is empty, that means that
|
||||
// callers might need to check if there are elements in the buffer first.
|
||||
func (r *RingBuffer[T]) PopFront() T {
|
||||
if r.Empty() {
|
||||
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue")
|
||||
}
|
||||
r.full = false
|
||||
t := r.ring[r.headPos]
|
||||
r.ring[r.headPos] = *new(T)
|
||||
r.headPos++
|
||||
if r.headPos == len(r.ring) {
|
||||
r.headPos = 0
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Offset returns the offset element.
|
||||
// It must not be called when the buffer is empty, that means that
|
||||
// callers might need to check if there are elements in the buffer first
|
||||
// and check if the index larger than buffer length.
|
||||
func (r *RingBuffer[T]) Offset(index int) *T {
|
||||
if r.Empty() || index >= r.Len() {
|
||||
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index")
|
||||
}
|
||||
offset := (r.headPos + index) % len(r.ring)
|
||||
return &r.ring[offset]
|
||||
}
|
||||
|
||||
// Front returns the front element.
|
||||
// It must not be called when the buffer is empty, that means that
|
||||
// callers might need to check if there are elements in the buffer first.
|
||||
func (r *RingBuffer[T]) Front() *T {
|
||||
if r.Empty() {
|
||||
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue")
|
||||
}
|
||||
return &r.ring[r.headPos]
|
||||
}
|
||||
|
||||
// Back returns the back element.
|
||||
// It must not be called when the buffer is empty, that means that
|
||||
// callers might need to check if there are elements in the buffer first.
|
||||
func (r *RingBuffer[T]) Back() *T {
|
||||
if r.Empty() {
|
||||
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue")
|
||||
}
|
||||
return r.Offset(r.Len() - 1)
|
||||
}
|
||||
|
||||
// Grow the maximum size of the queue.
|
||||
// This method assume the queue is full.
|
||||
func (r *RingBuffer[T]) grow() {
|
||||
oldRing := r.ring
|
||||
newSize := len(oldRing) * 2
|
||||
if newSize == 0 {
|
||||
newSize = 1
|
||||
}
|
||||
r.ring = make([]T, newSize)
|
||||
headLen := copy(r.ring, oldRing[r.headPos:])
|
||||
copy(r.ring[headLen:], oldRing[:r.headPos])
|
||||
r.headPos, r.tailPos, r.full = 0, len(oldRing), false
|
||||
}
|
||||
|
||||
// Clear removes all elements.
|
||||
func (r *RingBuffer[T]) Clear() {
|
||||
var zeroValue T
|
||||
for i := range r.ring {
|
||||
r.ring[i] = zeroValue
|
||||
}
|
||||
r.headPos, r.tailPos, r.full = 0, 0, false
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
package bbr
|
||||
|
||||
import (
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum)
|
||||
// estimate of a stream of samples over some fixed time interval. (E.g.,
|
||||
// the minimum RTT over the past five minutes.) The algorithm keeps track of
|
||||
// the best, second best, and third best min (or max) estimates, maintaining an
|
||||
// invariant that the measurement time of the n'th best >= n-1'th best.
|
||||
|
||||
// The algorithm works as follows. On a reset, all three estimates are set to
|
||||
// the same sample. The second best estimate is then recorded in the second
|
||||
// quarter of the window, and a third best estimate is recorded in the second
|
||||
// half of the window, bounding the worst case error when the true min is
|
||||
// monotonically increasing (or true max is monotonically decreasing) over the
|
||||
// window.
|
||||
//
|
||||
// A new best sample replaces all three estimates, since the new best is lower
|
||||
// (or higher) than everything else in the window and it is the most recent.
|
||||
// The window thus effectively gets reset on every new min. The same property
|
||||
// holds true for second best and third best estimates. Specifically, when a
|
||||
// sample arrives that is better than the second best but not better than the
|
||||
// best, it replaces the second and third best estimates but not the best
|
||||
// estimate. Similarly, a sample that is better than the third best estimate
|
||||
// but not the other estimates replaces only the third best estimate.
|
||||
//
|
||||
// Finally, when the best expires, it is replaced by the second best, which in
|
||||
// turn is replaced by the third best. The newest sample replaces the third
|
||||
// best.
|
||||
|
||||
type WindowedFilterValue interface {
|
||||
any
|
||||
}
|
||||
|
||||
type WindowedFilterTime interface {
|
||||
constraints.Integer | constraints.Float
|
||||
}
|
||||
|
||||
type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct {
|
||||
// Time length of window.
|
||||
windowLength T
|
||||
estimates []entry[V, T]
|
||||
comparator func(V, V) int
|
||||
}
|
||||
|
||||
type entry[V WindowedFilterValue, T WindowedFilterTime] struct {
|
||||
sample V
|
||||
time T
|
||||
}
|
||||
|
||||
// Compares two values and returns true if the first is greater than or equal
|
||||
// to the second.
|
||||
func MaxFilter[O constraints.Ordered](a, b O) int {
|
||||
if a > b {
|
||||
return 1
|
||||
} else if a < b {
|
||||
return -1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Compares two values and returns true if the first is less than or equal
|
||||
// to the second.
|
||||
func MinFilter[O constraints.Ordered](a, b O) int {
|
||||
if a < b {
|
||||
return 1
|
||||
} else if a > b {
|
||||
return -1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] {
|
||||
return &WindowedFilter[V, T]{
|
||||
windowLength: windowLength,
|
||||
estimates: make([]entry[V, T], 3, 3),
|
||||
comparator: comparator,
|
||||
}
|
||||
}
|
||||
|
||||
// Changes the window length. Does not update any current samples.
|
||||
func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) {
|
||||
f.windowLength = windowLength
|
||||
}
|
||||
|
||||
func (f *WindowedFilter[V, T]) GetBest() V {
|
||||
return f.estimates[0].sample
|
||||
}
|
||||
|
||||
func (f *WindowedFilter[V, T]) GetSecondBest() V {
|
||||
return f.estimates[1].sample
|
||||
}
|
||||
|
||||
func (f *WindowedFilter[V, T]) GetThirdBest() V {
|
||||
return f.estimates[2].sample
|
||||
}
|
||||
|
||||
// Updates best estimates with |sample|, and expires and updates best
|
||||
// estimates as necessary.
|
||||
func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) {
|
||||
// Reset all estimates if they have not yet been initialized, if new sample
|
||||
// is a new best, or if the newest recorded estimate is too old.
|
||||
if f.comparator(f.estimates[0].sample, *new(V)) == 0 ||
|
||||
f.comparator(newSample, f.estimates[0].sample) >= 0 ||
|
||||
newTime-f.estimates[2].time > f.windowLength {
|
||||
f.Reset(newSample, newTime)
|
||||
return
|
||||
}
|
||||
|
||||
if f.comparator(newSample, f.estimates[1].sample) >= 0 {
|
||||
f.estimates[1] = entry[V, T]{newSample, newTime}
|
||||
f.estimates[2] = f.estimates[1]
|
||||
} else if f.comparator(newSample, f.estimates[2].sample) >= 0 {
|
||||
f.estimates[2] = entry[V, T]{newSample, newTime}
|
||||
}
|
||||
|
||||
// Expire and update estimates as necessary.
|
||||
if newTime-f.estimates[0].time > f.windowLength {
|
||||
// The best estimate hasn't been updated for an entire window, so promote
|
||||
// second and third best estimates.
|
||||
f.estimates[0] = f.estimates[1]
|
||||
f.estimates[1] = f.estimates[2]
|
||||
f.estimates[2] = entry[V, T]{newSample, newTime}
|
||||
// Need to iterate one more time. Check if the new best estimate is
|
||||
// outside the window as well, since it may also have been recorded a
|
||||
// long time ago. Don't need to iterate once more since we cover that
|
||||
// case at the beginning of the method.
|
||||
if newTime-f.estimates[0].time > f.windowLength {
|
||||
f.estimates[0] = f.estimates[1]
|
||||
f.estimates[1] = f.estimates[2]
|
||||
}
|
||||
return
|
||||
}
|
||||
if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 &&
|
||||
newTime-f.estimates[1].time > f.windowLength/4 {
|
||||
// A quarter of the window has passed without a better sample, so the
|
||||
// second-best estimate is taken from the second quarter of the window.
|
||||
f.estimates[1] = entry[V, T]{newSample, newTime}
|
||||
f.estimates[2] = f.estimates[1]
|
||||
return
|
||||
}
|
||||
|
||||
if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 &&
|
||||
newTime-f.estimates[2].time > f.windowLength/2 {
|
||||
// We've passed a half of the window without a better estimate, so take
|
||||
// a third-best estimate from the second half of the window.
|
||||
f.estimates[2] = entry[V, T]{newSample, newTime}
|
||||
}
|
||||
}
|
||||
|
||||
// Resets all estimates to new sample.
|
||||
func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) {
|
||||
f.estimates[2] = entry[V, T]{newSample, newTime}
|
||||
f.estimates[1] = f.estimates[2]
|
||||
f.estimates[0] = f.estimates[1]
|
||||
}
|
||||
|
||||
func (f *WindowedFilter[V, T]) Clear() {
|
||||
f.estimates = make([]entry[V, T], 3, 3)
|
||||
}
|
181
transport/hysteria2/core/internal/congestion/brutal/brutal.go
Normal file
181
transport/hysteria2/core/internal/congestion/brutal/brutal.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package brutal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/common"
|
||||
|
||||
"github.com/metacubex/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample
|
||||
minSampleCount = 50
|
||||
minAckRate = 0.8
|
||||
congestionWindowMultiplier = 2
|
||||
|
||||
debugEnv = "HYSTERIA_BRUTAL_DEBUG"
|
||||
debugPrintInterval = 2
|
||||
)
|
||||
|
||||
var _ congestion.CongestionControl = &BrutalSender{}
|
||||
|
||||
type BrutalSender struct {
|
||||
rttStats congestion.RTTStatsProvider
|
||||
bps congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
pacer *common.Pacer
|
||||
|
||||
pktInfoSlots [pktInfoSlotCount]pktInfo
|
||||
ackRate float64
|
||||
|
||||
debug bool
|
||||
lastAckPrintTimestamp int64
|
||||
}
|
||||
|
||||
type pktInfo struct {
|
||||
Timestamp int64
|
||||
AckCount uint64
|
||||
LossCount uint64
|
||||
}
|
||||
|
||||
func NewBrutalSender(bps uint64) *BrutalSender {
|
||||
debug, _ := strconv.ParseBool(os.Getenv(debugEnv))
|
||||
bs := &BrutalSender{
|
||||
bps: congestion.ByteCount(bps),
|
||||
maxDatagramSize: congestion.InitialPacketSizeIPv4,
|
||||
ackRate: 1,
|
||||
debug: debug,
|
||||
}
|
||||
bs.pacer = common.NewPacer(func() congestion.ByteCount {
|
||||
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
||||
})
|
||||
return bs
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
||||
b.rttStats = rttStats
|
||||
}
|
||||
|
||||
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
||||
return b.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
|
||||
return b.pacer.Budget(now) >= b.maxDatagramSize
|
||||
}
|
||||
|
||||
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||
return bytesInFlight < b.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
||||
rtt := b.rttStats.SmoothedRTT()
|
||||
if rtt <= 0 {
|
||||
return 10240
|
||||
}
|
||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
||||
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
|
||||
) {
|
||||
b.pacer.SentPacket(sentTime, bytes)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount, eventTime time.Time,
|
||||
) {
|
||||
// Stub
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount,
|
||||
) {
|
||||
// Stub
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
|
||||
currentTimestamp := eventTime.Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
|
||||
b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
|
||||
b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
||||
b.maxDatagramSize = size
|
||||
b.pacer.SetMaxDatagramSize(size)
|
||||
if b.debug {
|
||||
b.debugPrint("SetMaxDatagramSize: %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
||||
minTimestamp := currentTimestamp - pktInfoSlotCount
|
||||
var ackCount, lossCount uint64
|
||||
for _, info := range b.pktInfoSlots {
|
||||
if info.Timestamp < minTimestamp {
|
||||
continue
|
||||
}
|
||||
ackCount += info.AckCount
|
||||
lossCount += info.LossCount
|
||||
}
|
||||
if ackCount+lossCount < minSampleCount {
|
||||
b.ackRate = 1
|
||||
if b.canPrintAckRate(currentTimestamp) {
|
||||
b.lastAckPrintTimestamp = currentTimestamp
|
||||
b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
|
||||
ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
||||
}
|
||||
return
|
||||
}
|
||||
rate := float64(ackCount) / float64(ackCount+lossCount)
|
||||
if rate < minAckRate {
|
||||
b.ackRate = minAckRate
|
||||
if b.canPrintAckRate(currentTimestamp) {
|
||||
b.lastAckPrintTimestamp = currentTimestamp
|
||||
b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
|
||||
rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
||||
}
|
||||
return
|
||||
}
|
||||
b.ackRate = rate
|
||||
if b.canPrintAckRate(currentTimestamp) {
|
||||
b.lastAckPrintTimestamp = currentTimestamp
|
||||
b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
|
||||
rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InSlowStart() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InRecovery() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) MaybeExitSlowStart() {}
|
||||
|
||||
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
||||
|
||||
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
|
||||
return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
|
||||
}
|
||||
|
||||
func (b *BrutalSender) debugPrint(format string, a ...any) {
|
||||
fmt.Printf("[BrutalSender] [%s] %s\n",
|
||||
time.Now().Format("15:04:05"),
|
||||
fmt.Sprintf(format, a...))
|
||||
}
|
95
transport/hysteria2/core/internal/congestion/common/pacer.go
Normal file
95
transport/hysteria2/core/internal/congestion/common/pacer.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstPackets = 10
|
||||
)
|
||||
|
||||
// Pacer implements a token bucket pacing algorithm.
|
||||
type Pacer struct {
|
||||
budgetAtLastSent congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
lastSentTime time.Time
|
||||
getBandwidth func() congestion.ByteCount // in bytes/s
|
||||
}
|
||||
|
||||
func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer {
|
||||
p := &Pacer{
|
||||
budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4,
|
||||
maxDatagramSize: congestion.InitialPacketSizeIPv4,
|
||||
getBandwidth: getBandwidth,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *Pacer) Budget(now time.Time) congestion.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
if budget < 0 { // protect against overflows
|
||||
budget = congestion.ByteCount(1<<62 - 1)
|
||||
}
|
||||
return minByteCount(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *Pacer) maxBurstSize() congestion.ByteCount {
|
||||
return maxByteCount(
|
||||
congestion.ByteCount((congestion.MinPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
|
||||
maxBurstPackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *Pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(maxDuration(
|
||||
congestion.MinPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
|
||||
float64(p.getBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
||||
|
||||
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func maxDuration(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
18
transport/hysteria2/core/internal/congestion/utils.go
Normal file
18
transport/hysteria2/core/internal/congestion/utils.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/bbr"
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/congestion/brutal"
|
||||
"github.com/metacubex/quic-go"
|
||||
)
|
||||
|
||||
func UseBBR(conn quic.Connection) {
|
||||
conn.SetCongestionControl(bbr.NewBbrSender(
|
||||
bbr.DefaultClock{},
|
||||
bbr.GetInitialPacketSize(conn.RemoteAddr()),
|
||||
))
|
||||
}
|
||||
|
||||
func UseBrutal(conn quic.Connection, tx uint64) {
|
||||
conn.SetCongestionControl(brutal.NewBrutalSender(tx))
|
||||
}
|
77
transport/hysteria2/core/internal/frag/frag.go
Normal file
77
transport/hysteria2/core/internal/frag/frag.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package frag
|
||||
|
||||
import (
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/internal/protocol"
|
||||
)
|
||||
|
||||
func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
||||
if m.Size() <= maxSize {
|
||||
return []protocol.UDPMessage{*m}
|
||||
}
|
||||
fullPayload := m.Data
|
||||
maxPayloadSize := maxSize - m.HeaderSize()
|
||||
off := 0
|
||||
fragID := uint8(0)
|
||||
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
||||
frags := make([]protocol.UDPMessage, fragCount)
|
||||
for off < len(fullPayload) {
|
||||
payloadSize := len(fullPayload) - off
|
||||
if payloadSize > maxPayloadSize {
|
||||
payloadSize = maxPayloadSize
|
||||
}
|
||||
frag := *m
|
||||
frag.FragID = fragID
|
||||
frag.FragCount = fragCount
|
||||
frag.Data = fullPayload[off : off+payloadSize]
|
||||
frags[fragID] = frag
|
||||
off += payloadSize
|
||||
fragID++
|
||||
}
|
||||
return frags
|
||||
}
|
||||
|
||||
// Defragger handles the defragmentation of UDP messages.
|
||||
// The current implementation can only handle one packet ID at a time.
|
||||
// If another packet arrives before a packet has received all fragments
|
||||
// in their entirety, any previous state is discarded.
|
||||
type Defragger struct {
|
||||
pktID uint16
|
||||
frags []*protocol.UDPMessage
|
||||
count uint8
|
||||
size int // data size
|
||||
}
|
||||
|
||||
func (d *Defragger) Feed(m *protocol.UDPMessage) *protocol.UDPMessage {
|
||||
if m.FragCount <= 1 {
|
||||
return m
|
||||
}
|
||||
if m.FragID >= m.FragCount {
|
||||
// wtf is this?
|
||||
return nil
|
||||
}
|
||||
if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) {
|
||||
// new message, clear previous state
|
||||
d.pktID = m.PacketID
|
||||
d.frags = make([]*protocol.UDPMessage, m.FragCount)
|
||||
d.frags[m.FragID] = m
|
||||
d.count = 1
|
||||
d.size = len(m.Data)
|
||||
} else if d.frags[m.FragID] == nil {
|
||||
d.frags[m.FragID] = m
|
||||
d.count++
|
||||
d.size += len(m.Data)
|
||||
if int(d.count) == len(d.frags) {
|
||||
// all fragments received, assemble
|
||||
data := make([]byte, d.size)
|
||||
off := 0
|
||||
for _, frag := range d.frags {
|
||||
off += copy(data[off:], frag.Data)
|
||||
}
|
||||
m.Data = data
|
||||
m.FragID = 0
|
||||
m.FragCount = 1
|
||||
return m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
7
transport/hysteria2/core/internal/pmtud/avail.go
Normal file
7
transport/hysteria2/core/internal/pmtud/avail.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
//go:build linux || windows || darwin
|
||||
|
||||
package pmtud
|
||||
|
||||
const (
|
||||
DisablePathMTUDiscovery = false
|
||||
)
|
13
transport/hysteria2/core/internal/pmtud/unavail.go
Normal file
13
transport/hysteria2/core/internal/pmtud/unavail.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package pmtud
|
||||
|
||||
// quic-go's MTU detection is enabled by default on all platforms.
|
||||
// However, it only actually sets the DF bit on 3 supported platforms (Windows, macOS, Linux).
|
||||
// As a result, on other platforms, probe packets that should never be fragmented will still
|
||||
// be fragmented and transmitted. So we're only enabling it for platforms where we've verified
|
||||
// its functionality for now.
|
||||
|
||||
const (
|
||||
DisablePathMTUDiscovery = true
|
||||
)
|
68
transport/hysteria2/core/internal/protocol/http.go
Normal file
68
transport/hysteria2/core/internal/protocol/http.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
URLHost = "hysteria"
|
||||
URLPath = "/auth"
|
||||
|
||||
RequestHeaderAuth = "Hysteria-Auth"
|
||||
ResponseHeaderUDPEnabled = "Hysteria-UDP"
|
||||
CommonHeaderCCRX = "Hysteria-CC-RX"
|
||||
CommonHeaderPadding = "Hysteria-Padding"
|
||||
|
||||
StatusAuthOK = 233
|
||||
)
|
||||
|
||||
// AuthRequest is what client sends to server for authentication.
|
||||
type AuthRequest struct {
|
||||
Auth string
|
||||
Rx uint64 // 0 = unknown, client asks server to use bandwidth detection
|
||||
}
|
||||
|
||||
// AuthResponse is what server sends to client when authentication is passed.
|
||||
type AuthResponse struct {
|
||||
UDPEnabled bool
|
||||
Rx uint64 // 0 = unlimited
|
||||
RxAuto bool // true = server asks client to use bandwidth detection
|
||||
}
|
||||
|
||||
func AuthRequestFromHeader(h http.Header) AuthRequest {
|
||||
rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
|
||||
return AuthRequest{
|
||||
Auth: h.Get(RequestHeaderAuth),
|
||||
Rx: rx,
|
||||
}
|
||||
}
|
||||
|
||||
func AuthRequestToHeader(h http.Header, req AuthRequest) {
|
||||
h.Set(RequestHeaderAuth, req.Auth)
|
||||
h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10))
|
||||
h.Set(CommonHeaderPadding, authRequestPadding.String())
|
||||
}
|
||||
|
||||
func AuthResponseFromHeader(h http.Header) AuthResponse {
|
||||
resp := AuthResponse{}
|
||||
resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
|
||||
rxStr := h.Get(CommonHeaderCCRX)
|
||||
if rxStr == "auto" {
|
||||
// Special case for server requesting client to use bandwidth detection
|
||||
resp.RxAuto = true
|
||||
} else {
|
||||
resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func AuthResponseToHeader(h http.Header, resp AuthResponse) {
|
||||
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled))
|
||||
if resp.RxAuto {
|
||||
h.Set(CommonHeaderCCRX, "auto")
|
||||
} else {
|
||||
h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10))
|
||||
}
|
||||
h.Set(CommonHeaderPadding, authResponsePadding.String())
|
||||
}
|
31
transport/hysteria2/core/internal/protocol/padding.go
Normal file
31
transport/hysteria2/core/internal/protocol/padding.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
const (
|
||||
paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
)
|
||||
|
||||
// padding specifies a half-open range [Min, Max).
|
||||
type padding struct {
|
||||
Min int
|
||||
Max int
|
||||
}
|
||||
|
||||
func (p padding) String() string {
|
||||
n := p.Min + rand.Intn(p.Max-p.Min)
|
||||
bs := make([]byte, n)
|
||||
for i := range bs {
|
||||
bs[i] = paddingChars[rand.Intn(len(paddingChars))]
|
||||
}
|
||||
return string(bs)
|
||||
}
|
||||
|
||||
var (
|
||||
authRequestPadding = padding{Min: 256, Max: 2048}
|
||||
authResponsePadding = padding{Min: 256, Max: 2048}
|
||||
tcpRequestPadding = padding{Min: 64, Max: 512}
|
||||
tcpResponsePadding = padding{Min: 128, Max: 1024}
|
||||
)
|
255
transport/hysteria2/core/internal/protocol/proxy.go
Normal file
255
transport/hysteria2/core/internal/protocol/proxy.go
Normal file
|
@ -0,0 +1,255 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/hysteria2/core/errors"
|
||||
|
||||
"github.com/metacubex/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const (
|
||||
FrameTypeTCPRequest = 0x401
|
||||
|
||||
// Max length values are for preventing DoS attacks
|
||||
|
||||
MaxAddressLength = 2048
|
||||
MaxMessageLength = 2048
|
||||
MaxPaddingLength = 4096
|
||||
|
||||
MaxUDPSize = 4096
|
||||
|
||||
maxVarInt1 = 63
|
||||
maxVarInt2 = 16383
|
||||
maxVarInt4 = 1073741823
|
||||
maxVarInt8 = 4611686018427387903
|
||||
)
|
||||
|
||||
// TCPRequest format:
|
||||
// 0x401 (QUIC varint)
|
||||
// Address length (QUIC varint)
|
||||
// Address (bytes)
|
||||
// Padding length (QUIC varint)
|
||||
// Padding (bytes)
|
||||
|
||||
func ReadTCPRequest(r io.Reader) (string, error) {
|
||||
bReader := quicvarint.NewReader(r)
|
||||
addrLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if addrLen == 0 || addrLen > MaxAddressLength {
|
||||
return "", errors.ProtocolError{Message: "invalid address length"}
|
||||
}
|
||||
addrBuf := make([]byte, addrLen)
|
||||
_, err = io.ReadFull(r, addrBuf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
paddingLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if paddingLen > MaxPaddingLength {
|
||||
return "", errors.ProtocolError{Message: "invalid padding length"}
|
||||
}
|
||||
if paddingLen > 0 {
|
||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return string(addrBuf), nil
|
||||
}
|
||||
|
||||
func WriteTCPRequest(w io.Writer, addr string) error {
|
||||
padding := tcpRequestPadding.String()
|
||||
paddingLen := len(padding)
|
||||
addrLen := len(addr)
|
||||
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
|
||||
int(quicvarint.Len(uint64(addrLen))) + addrLen +
|
||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||
buf := make([]byte, sz)
|
||||
i := varintPut(buf, FrameTypeTCPRequest)
|
||||
i += varintPut(buf[i:], uint64(addrLen))
|
||||
i += copy(buf[i:], addr)
|
||||
i += varintPut(buf[i:], uint64(paddingLen))
|
||||
copy(buf[i:], padding)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// TCPResponse format:
|
||||
// Status (byte, 0=ok, 1=error)
|
||||
// Message length (QUIC varint)
|
||||
// Message (bytes)
|
||||
// Padding length (QUIC varint)
|
||||
// Padding (bytes)
|
||||
|
||||
func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
||||
var status [1]byte
|
||||
if _, err := io.ReadFull(r, status[:]); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
bReader := quicvarint.NewReader(r)
|
||||
msgLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if msgLen > MaxMessageLength {
|
||||
return false, "", errors.ProtocolError{Message: "invalid message length"}
|
||||
}
|
||||
var msgBuf []byte
|
||||
// No message is fine
|
||||
if msgLen > 0 {
|
||||
msgBuf = make([]byte, msgLen)
|
||||
_, err = io.ReadFull(r, msgBuf)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
}
|
||||
paddingLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if paddingLen > MaxPaddingLength {
|
||||
return false, "", errors.ProtocolError{Message: "invalid padding length"}
|
||||
}
|
||||
if paddingLen > 0 {
|
||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
}
|
||||
return status[0] == 0, string(msgBuf), nil
|
||||
}
|
||||
|
||||
func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
|
||||
padding := tcpResponsePadding.String()
|
||||
paddingLen := len(padding)
|
||||
msgLen := len(msg)
|
||||
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||
buf := make([]byte, sz)
|
||||
if ok {
|
||||
buf[0] = 0
|
||||
} else {
|
||||
buf[0] = 1
|
||||
}
|
||||
i := varintPut(buf[1:], uint64(msgLen))
|
||||
i += copy(buf[1+i:], msg)
|
||||
i += varintPut(buf[1+i:], uint64(paddingLen))
|
||||
copy(buf[1+i:], padding)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// UDPMessage format:
|
||||
// Session ID (uint32 BE)
|
||||
// Packet ID (uint16 BE)
|
||||
// Fragment ID (uint8)
|
||||
// Fragment count (uint8)
|
||||
// Address length (QUIC varint)
|
||||
// Address (bytes)
|
||||
// Data...
|
||||
|
||||
type UDPMessage struct {
|
||||
SessionID uint32 // 4
|
||||
PacketID uint16 // 2
|
||||
FragID uint8 // 1
|
||||
FragCount uint8 // 1
|
||||
Addr string // varint + bytes
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (m *UDPMessage) HeaderSize() int {
|
||||
lAddr := len(m.Addr)
|
||||
return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
|
||||
}
|
||||
|
||||
func (m *UDPMessage) Size() int {
|
||||
return m.HeaderSize() + len(m.Data)
|
||||
}
|
||||
|
||||
func (m *UDPMessage) Serialize(buf []byte) int {
|
||||
// Make sure the buffer is big enough
|
||||
if len(buf) < m.Size() {
|
||||
return -1
|
||||
}
|
||||
binary.BigEndian.PutUint32(buf, m.SessionID)
|
||||
binary.BigEndian.PutUint16(buf[4:], m.PacketID)
|
||||
buf[6] = m.FragID
|
||||
buf[7] = m.FragCount
|
||||
i := varintPut(buf[8:], uint64(len(m.Addr)))
|
||||
i += copy(buf[8+i:], m.Addr)
|
||||
i += copy(buf[8+i:], m.Data)
|
||||
return 8 + i
|
||||
}
|
||||
|
||||
func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
|
||||
m := &UDPMessage{}
|
||||
buf := bytes.NewBuffer(msg)
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lAddr, err := quicvarint.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lAddr == 0 || lAddr > MaxMessageLength {
|
||||
return nil, errors.ProtocolError{Message: "invalid address length"}
|
||||
}
|
||||
bs := buf.Bytes()
|
||||
if len(bs) <= int(lAddr) {
|
||||
// We use <= instead of < here as we expect at least one byte of data after the address
|
||||
return nil, errors.ProtocolError{Message: "invalid message length"}
|
||||
}
|
||||
m.Addr = string(bs[:lAddr])
|
||||
m.Data = bs[lAddr:]
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// varintPut is like quicvarint.Append, but instead of appending to a slice,
|
||||
// it writes to a fixed-size buffer. Returns the number of bytes written.
|
||||
func varintPut(b []byte, i uint64) int {
|
||||
if i <= maxVarInt1 {
|
||||
b[0] = uint8(i)
|
||||
return 1
|
||||
}
|
||||
if i <= maxVarInt2 {
|
||||
b[0] = uint8(i>>8) | 0x40
|
||||
b[1] = uint8(i)
|
||||
return 2
|
||||
}
|
||||
if i <= maxVarInt4 {
|
||||
b[0] = uint8(i>>24) | 0x80
|
||||
b[1] = uint8(i >> 16)
|
||||
b[2] = uint8(i >> 8)
|
||||
b[3] = uint8(i)
|
||||
return 4
|
||||
}
|
||||
if i <= maxVarInt8 {
|
||||
b[0] = uint8(i>>56) | 0xc0
|
||||
b[1] = uint8(i >> 48)
|
||||
b[2] = uint8(i >> 40)
|
||||
b[3] = uint8(i >> 32)
|
||||
b[4] = uint8(i >> 24)
|
||||
b[5] = uint8(i >> 16)
|
||||
b[6] = uint8(i >> 8)
|
||||
b[7] = uint8(i)
|
||||
return 8
|
||||
}
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
|
||||
}
|
24
transport/hysteria2/core/internal/utils/atomic.go
Normal file
24
transport/hysteria2/core/internal/utils/atomic.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AtomicTime struct {
|
||||
v atomic.Value
|
||||
}
|
||||
|
||||
func NewAtomicTime(t time.Time) *AtomicTime {
|
||||
a := &AtomicTime{}
|
||||
a.Set(t)
|
||||
return a
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Set(new time.Time) {
|
||||
t.v.Store(new)
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Get() time.Time {
|
||||
return t.v.Load().(time.Time)
|
||||
}
|
62
transport/hysteria2/core/internal/utils/qstream.go
Normal file
62
transport/hysteria2/core/internal/utils/qstream.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/quic-go"
|
||||
)
|
||||
|
||||
// QStream is a wrapper of quic.Stream that handles Close() in a way that
|
||||
// makes more sense to us. By default, quic.Stream's Close() only closes
|
||||
// the write side of the stream, not the read side. And if there is unread
|
||||
// data, the stream is not really considered closed until either the data
|
||||
// is drained or CancelRead() is called.
|
||||
// References:
|
||||
// - https://github.com/libp2p/go-libp2p/blob/master/p2p/transport/quic/stream.go
|
||||
// - https://github.com/quic-go/quic-go/issues/3558
|
||||
// - https://github.com/quic-go/quic-go/issues/1599
|
||||
type QStream struct {
|
||||
Stream quic.Stream
|
||||
}
|
||||
|
||||
func (s *QStream) StreamID() quic.StreamID {
|
||||
return s.Stream.StreamID()
|
||||
}
|
||||
|
||||
func (s *QStream) Read(p []byte) (n int, err error) {
|
||||
return s.Stream.Read(p)
|
||||
}
|
||||
|
||||
func (s *QStream) CancelRead(code quic.StreamErrorCode) {
|
||||
s.Stream.CancelRead(code)
|
||||
}
|
||||
|
||||
func (s *QStream) SetReadDeadline(t time.Time) error {
|
||||
return s.Stream.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (s *QStream) Write(p []byte) (n int, err error) {
|
||||
return s.Stream.Write(p)
|
||||
}
|
||||
|
||||
func (s *QStream) Close() error {
|
||||
s.Stream.CancelRead(0)
|
||||
return s.Stream.Close()
|
||||
}
|
||||
|
||||
func (s *QStream) CancelWrite(code quic.StreamErrorCode) {
|
||||
s.Stream.CancelWrite(code)
|
||||
}
|
||||
|
||||
func (s *QStream) Context() context.Context {
|
||||
return s.Stream.Context()
|
||||
}
|
||||
|
||||
func (s *QStream) SetWriteDeadline(t time.Time) error {
|
||||
return s.Stream.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (s *QStream) SetDeadline(t time.Time) error {
|
||||
return s.Stream.SetDeadline(t)
|
||||
}
|
92
transport/hysteria2/extras/correctnet/correctnet.go
Normal file
92
transport/hysteria2/extras/correctnet/correctnet.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package correctnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func extractIPFamily(ip net.IP) (family string) {
|
||||
if len(ip) == 0 {
|
||||
// real family independent wildcard address, such as ":443"
|
||||
return ""
|
||||
}
|
||||
if p4 := ip.To4(); len(p4) == net.IPv4len {
|
||||
return "4"
|
||||
}
|
||||
return "6"
|
||||
}
|
||||
|
||||
func tcpAddrNetwork(addr *net.TCPAddr) (network string) {
|
||||
if addr == nil {
|
||||
return "tcp"
|
||||
}
|
||||
return "tcp" + extractIPFamily(addr.IP)
|
||||
}
|
||||
|
||||
func udpAddrNetwork(addr *net.UDPAddr) (network string) {
|
||||
if addr == nil {
|
||||
return "udp"
|
||||
}
|
||||
return "udp" + extractIPFamily(addr.IP)
|
||||
}
|
||||
|
||||
func ipAddrNetwork(addr *net.IPAddr) (network string) {
|
||||
if addr == nil {
|
||||
return "ip"
|
||||
}
|
||||
return "ip" + extractIPFamily(addr.IP)
|
||||
}
|
||||
|
||||
func Listen(network, address string) (net.Listener, error) {
|
||||
if network == "tcp" {
|
||||
tcpAddr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ListenTCP(network, tcpAddr)
|
||||
}
|
||||
return net.Listen(network, address)
|
||||
}
|
||||
|
||||
func ListenTCP(network string, laddr *net.TCPAddr) (*net.TCPListener, error) {
|
||||
if network == "tcp" {
|
||||
return net.ListenTCP(tcpAddrNetwork(laddr), laddr)
|
||||
}
|
||||
return net.ListenTCP(network, laddr)
|
||||
}
|
||||
|
||||
func ListenPacket(network, address string) (listener net.PacketConn, err error) {
|
||||
if network == "udp" {
|
||||
udpAddr, err := net.ResolveUDPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ListenUDP(network, udpAddr)
|
||||
}
|
||||
if strings.HasPrefix(network, "ip:") {
|
||||
proto := network[3:]
|
||||
ipAddr, err := net.ResolveIPAddr(proto, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenIP(ipAddrNetwork(ipAddr)+":"+proto, ipAddr)
|
||||
}
|
||||
return net.ListenPacket(network, address)
|
||||
}
|
||||
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if network == "udp" {
|
||||
return net.ListenUDP(udpAddrNetwork(laddr), laddr)
|
||||
}
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
func HTTPListenAndServe(address string, handler http.Handler) error {
|
||||
listener, err := Listen("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
return http.Serve(listener, handler)
|
||||
}
|
121
transport/hysteria2/extras/obfs/conn.go
Normal file
121
transport/hysteria2/extras/obfs/conn.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package obfs
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
|
||||
|
||||
// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods.
|
||||
// Both methods return the number of bytes written to out.
|
||||
// If a packet is not valid, the methods should return 0.
|
||||
type Obfuscator interface {
|
||||
Obfuscate(in, out []byte) int
|
||||
Deobfuscate(in, out []byte) int
|
||||
}
|
||||
|
||||
var _ net.PacketConn = (*obfsPacketConn)(nil)
|
||||
|
||||
type obfsPacketConn struct {
|
||||
Conn net.PacketConn
|
||||
Obfs Obfuscator
|
||||
|
||||
readBuf []byte
|
||||
readMutex sync.Mutex
|
||||
writeBuf []byte
|
||||
writeMutex sync.Mutex
|
||||
}
|
||||
|
||||
// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn
|
||||
// as the underlying connection. We pass additional methods to quic-go to
|
||||
// enable UDP-specific optimizations.
|
||||
type obfsPacketConnUDP struct {
|
||||
*obfsPacketConn
|
||||
UDPConn *net.UDPConn
|
||||
}
|
||||
|
||||
// WrapPacketConn enables obfuscation on a net.PacketConn.
|
||||
// The obfuscation is transparent to the caller - the n bytes returned by
|
||||
// ReadFrom and WriteTo are the number of original bytes, not after
|
||||
// obfuscation/deobfuscation.
|
||||
func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn {
|
||||
opc := &obfsPacketConn{
|
||||
Conn: conn,
|
||||
Obfs: obfs,
|
||||
readBuf: make([]byte, udpBufferSize),
|
||||
writeBuf: make([]byte, udpBufferSize),
|
||||
}
|
||||
if udpConn, ok := conn.(*net.UDPConn); ok {
|
||||
return &obfsPacketConnUDP{
|
||||
obfsPacketConn: opc,
|
||||
UDPConn: udpConn,
|
||||
}
|
||||
} else {
|
||||
return opc
|
||||
}
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
for {
|
||||
c.readMutex.Lock()
|
||||
n, addr, err = c.Conn.ReadFrom(c.readBuf)
|
||||
if n <= 0 {
|
||||
c.readMutex.Unlock()
|
||||
return
|
||||
}
|
||||
n = c.Obfs.Deobfuscate(c.readBuf[:n], p)
|
||||
c.readMutex.Unlock()
|
||||
if n > 0 || err != nil {
|
||||
return
|
||||
}
|
||||
// Invalid packet, try again
|
||||
}
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
c.writeMutex.Lock()
|
||||
nn := c.Obfs.Obfuscate(p, c.writeBuf)
|
||||
_, err = c.Conn.WriteTo(c.writeBuf[:nn], addr)
|
||||
c.writeMutex.Unlock()
|
||||
if err == nil {
|
||||
n = len(p)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) LocalAddr() net.Addr {
|
||||
return c.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) SetDeadline(t time.Time) error {
|
||||
return c.Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) SetReadDeadline(t time.Time) error {
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// UDP-specific methods below
|
||||
|
||||
func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error {
|
||||
return c.UDPConn.SetReadBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error {
|
||||
return c.UDPConn.SetWriteBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) {
|
||||
return c.UDPConn.SyscallConn()
|
||||
}
|
71
transport/hysteria2/extras/obfs/salamander.go
Normal file
71
transport/hysteria2/extras/obfs/salamander.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package obfs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2b"
|
||||
)
|
||||
|
||||
const (
|
||||
smPSKMinLen = 4
|
||||
smSaltLen = 8
|
||||
smKeyLen = blake2b.Size256
|
||||
)
|
||||
|
||||
var _ Obfuscator = (*SalamanderObfuscator)(nil)
|
||||
|
||||
var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen)
|
||||
|
||||
// SalamanderObfuscator is an obfuscator that obfuscates each packet with
|
||||
// the BLAKE2b-256 hash of a pre-shared key combined with a random salt.
|
||||
// Packet format: [8-byte salt][payload]
|
||||
type SalamanderObfuscator struct {
|
||||
PSK []byte
|
||||
RandSrc *rand.Rand
|
||||
|
||||
lk sync.Mutex
|
||||
}
|
||||
|
||||
func NewSalamanderObfuscator(psk []byte) (*SalamanderObfuscator, error) {
|
||||
if len(psk) < smPSKMinLen {
|
||||
return nil, ErrPSKTooShort
|
||||
}
|
||||
return &SalamanderObfuscator{
|
||||
PSK: psk,
|
||||
RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *SalamanderObfuscator) Obfuscate(in, out []byte) int {
|
||||
outLen := len(in) + smSaltLen
|
||||
if len(out) < outLen {
|
||||
return 0
|
||||
}
|
||||
o.lk.Lock()
|
||||
_, _ = o.RandSrc.Read(out[:smSaltLen])
|
||||
o.lk.Unlock()
|
||||
key := o.key(out[:smSaltLen])
|
||||
for i, c := range in {
|
||||
out[i+smSaltLen] = c ^ key[i%smKeyLen]
|
||||
}
|
||||
return outLen
|
||||
}
|
||||
|
||||
func (o *SalamanderObfuscator) Deobfuscate(in, out []byte) int {
|
||||
outLen := len(in) - smSaltLen
|
||||
if outLen <= 0 || len(out) < outLen {
|
||||
return 0
|
||||
}
|
||||
key := o.key(in[:smSaltLen])
|
||||
for i, c := range in[smSaltLen:] {
|
||||
out[i] = c ^ key[i%smKeyLen]
|
||||
}
|
||||
return outLen
|
||||
}
|
||||
|
||||
func (o *SalamanderObfuscator) key(salt []byte) [smKeyLen]byte {
|
||||
return blake2b.Sum256(append(o.PSK, salt...))
|
||||
}
|
45
transport/hysteria2/extras/obfs/salamander_test.go
Normal file
45
transport/hysteria2/extras/obfs/salamander_test.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package obfs
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) {
|
||||
o, _ := NewSalamanderObfuscator([]byte("average_password"))
|
||||
in := make([]byte, 1200)
|
||||
_, _ = rand.Read(in)
|
||||
out := make([]byte, 2048)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
o.Obfuscate(in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) {
|
||||
o, _ := NewSalamanderObfuscator([]byte("average_password"))
|
||||
in := make([]byte, 1200)
|
||||
_, _ = rand.Read(in)
|
||||
out := make([]byte, 2048)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
o.Deobfuscate(in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSalamanderObfuscator(t *testing.T) {
|
||||
o, _ := NewSalamanderObfuscator([]byte("average_password"))
|
||||
in := make([]byte, 1200)
|
||||
oOut := make([]byte, 2048)
|
||||
dOut := make([]byte, 2048)
|
||||
for i := 0; i < 1000; i++ {
|
||||
_, _ = rand.Read(in)
|
||||
n := o.Obfuscate(in, oOut)
|
||||
assert.Equal(t, len(in)+smSaltLen, n)
|
||||
n = o.Deobfuscate(oOut[:n], dOut)
|
||||
assert.Equal(t, len(in), n)
|
||||
assert.Equal(t, in, dOut[:n])
|
||||
}
|
||||
}
|
92
transport/hysteria2/extras/transport/udphop/addr.go
Normal file
92
transport/hysteria2/extras/transport/udphop/addr.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package udphop
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type InvalidPortError struct {
|
||||
PortStr string
|
||||
}
|
||||
|
||||
func (e InvalidPortError) Error() string {
|
||||
return fmt.Sprintf("%s is not a valid port number or range", e.PortStr)
|
||||
}
|
||||
|
||||
// UDPHopAddr contains an IP address and a list of ports.
|
||||
type UDPHopAddr struct {
|
||||
IP net.IP
|
||||
Ports []uint16
|
||||
PortStr string
|
||||
}
|
||||
|
||||
func (a *UDPHopAddr) Network() string {
|
||||
return "udphop"
|
||||
}
|
||||
|
||||
func (a *UDPHopAddr) String() string {
|
||||
return net.JoinHostPort(a.IP.String(), a.PortStr)
|
||||
}
|
||||
|
||||
// addrs returns a list of net.Addr's, one for each port.
|
||||
func (a *UDPHopAddr) addrs() ([]net.Addr, error) {
|
||||
var addrs []net.Addr
|
||||
for _, port := range a.Ports {
|
||||
addr := &net.UDPAddr{
|
||||
IP: a.IP,
|
||||
Port: int(port),
|
||||
}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip, err := net.ResolveIPAddr("ip", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := &UDPHopAddr{
|
||||
IP: ip.IP,
|
||||
PortStr: portStr,
|
||||
}
|
||||
|
||||
portStrs := strings.Split(portStr, ",")
|
||||
for _, portStr := range portStrs {
|
||||
if strings.Contains(portStr, "-") {
|
||||
// Port range
|
||||
portRange := strings.Split(portStr, "-")
|
||||
if len(portRange) != 2 {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
start, err := strconv.ParseUint(portRange[0], 10, 16)
|
||||
if err != nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
end, err := strconv.ParseUint(portRange[1], 10, 16)
|
||||
if err != nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
if start > end {
|
||||
start, end = end, start
|
||||
}
|
||||
for i := start; i <= end; i++ {
|
||||
result.Ports = append(result.Ports, uint16(i))
|
||||
}
|
||||
} else {
|
||||
// Single port
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, InvalidPortError{portStr}
|
||||
}
|
||||
result.Ports = append(result.Ports, uint16(port))
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
288
transport/hysteria2/extras/transport/udphop/conn.go
Normal file
288
transport/hysteria2/extras/transport/udphop/conn.go
Normal file
|
@ -0,0 +1,288 @@
|
|||
package udphop
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/log"
|
||||
)
|
||||
|
||||
const (
|
||||
packetQueueSize = 1024
|
||||
udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
|
||||
|
||||
defaultHopInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
type udpHopPacketConn struct {
|
||||
Addr net.Addr
|
||||
Addrs []net.Addr
|
||||
HopInterval time.Duration
|
||||
|
||||
connMutex sync.RWMutex
|
||||
prevConn net.PacketConn
|
||||
currentConn net.PacketConn
|
||||
addrIndex int
|
||||
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
|
||||
recvQueue chan *udpPacket
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
Buf []byte
|
||||
N int
|
||||
Addr net.Addr
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) {
|
||||
if hopInterval == 0 {
|
||||
hopInterval = defaultHopInterval
|
||||
} else if hopInterval < 5*time.Second {
|
||||
return nil, errors.New("hop interval must be at least 5 seconds")
|
||||
}
|
||||
addrs, err := addr.addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hConn := &udpHopPacketConn{
|
||||
Addr: addr,
|
||||
Addrs: addrs,
|
||||
HopInterval: hopInterval,
|
||||
prevConn: nil,
|
||||
currentConn: curConn,
|
||||
addrIndex: rand.Intn(len(addrs)),
|
||||
recvQueue: make(chan *udpPacket, packetQueueSize),
|
||||
closeChan: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, udpBufferSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
go hConn.recvLoop(curConn)
|
||||
go hConn.hopLoop()
|
||||
return hConn, nil
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) recvLoop(conn net.PacketConn) {
|
||||
for {
|
||||
buf := u.bufPool.Get().([]byte)
|
||||
n, addr, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
u.bufPool.Put(buf)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
// Only pass through timeout errors here, not permanent errors
|
||||
// like connection closed. Connection close is normal as we close
|
||||
// the old connection to exit this loop every time we hop.
|
||||
u.recvQueue <- &udpPacket{nil, 0, nil, netErr}
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case u.recvQueue <- &udpPacket{buf, n, addr, nil}:
|
||||
// Packet successfully queued
|
||||
default:
|
||||
// Queue is full, drop the packet
|
||||
u.bufPool.Put(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) hopLoop() {
|
||||
ticker := time.NewTicker(u.HopInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
u.hop()
|
||||
case <-u.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) hop() {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
if u.closed {
|
||||
return
|
||||
}
|
||||
newConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
// Could be temporary, just skip this hop
|
||||
return
|
||||
}
|
||||
// We need to keep receiving packets from the previous connection,
|
||||
// because otherwise there will be packet loss due to the time gap
|
||||
// between we hop to a new port and the server acknowledges this change.
|
||||
// So we do the following:
|
||||
// Close prevConn,
|
||||
// move currentConn to prevConn,
|
||||
// set newConn as currentConn,
|
||||
// start recvLoop on newConn.
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.Close() // recvLoop for this conn will exit
|
||||
}
|
||||
u.prevConn = u.currentConn
|
||||
u.currentConn = newConn
|
||||
// Set buffer sizes if previously set
|
||||
if u.readBufferSize > 0 {
|
||||
_ = trySetReadBuffer(u.currentConn, u.readBufferSize)
|
||||
}
|
||||
if u.writeBufferSize > 0 {
|
||||
_ = trySetWriteBuffer(u.currentConn, u.writeBufferSize)
|
||||
}
|
||||
go u.recvLoop(newConn)
|
||||
// Update addrIndex to a new random value
|
||||
u.addrIndex = rand.Intn(len(u.Addrs))
|
||||
log.Infoln("hopped to %s", u.Addrs[u.addrIndex])
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
for {
|
||||
select {
|
||||
case p := <-u.recvQueue:
|
||||
if p.Err != nil {
|
||||
return 0, nil, p.Err
|
||||
}
|
||||
// Currently we do not check whether the packet is from
|
||||
// the server or not due to performance reasons.
|
||||
n := copy(b, p.Buf[:p.N])
|
||||
u.bufPool.Put(p.Buf)
|
||||
return n, u.Addr, nil
|
||||
case <-u.closeChan:
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
if u.closed {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
// Skip the check for now, always write to the server,
|
||||
// for the same reason as in ReadFrom.
|
||||
return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex])
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) Close() error {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
if u.closed {
|
||||
return nil
|
||||
}
|
||||
// Close prevConn and currentConn
|
||||
// Close closeChan to unblock ReadFrom & hopLoop
|
||||
// Set closed flag to true to prevent double close
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.Close()
|
||||
}
|
||||
err := u.currentConn.Close()
|
||||
close(u.closeChan)
|
||||
u.closed = true
|
||||
u.Addrs = nil // For GC
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) LocalAddr() net.Addr {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
return u.currentConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetDeadline(t time.Time) error {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.SetDeadline(t)
|
||||
}
|
||||
return u.currentConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.SetReadDeadline(t)
|
||||
}
|
||||
return u.currentConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
if u.prevConn != nil {
|
||||
_ = u.prevConn.SetWriteDeadline(t)
|
||||
}
|
||||
return u.currentConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// UDP-specific methods below
|
||||
|
||||
func (u *udpHopPacketConn) SetReadBuffer(bytes int) error {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
u.readBufferSize = bytes
|
||||
if u.prevConn != nil {
|
||||
_ = trySetReadBuffer(u.prevConn, bytes)
|
||||
}
|
||||
return trySetReadBuffer(u.currentConn, bytes)
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error {
|
||||
u.connMutex.Lock()
|
||||
defer u.connMutex.Unlock()
|
||||
u.writeBufferSize = bytes
|
||||
if u.prevConn != nil {
|
||||
_ = trySetWriteBuffer(u.prevConn, bytes)
|
||||
}
|
||||
return trySetWriteBuffer(u.currentConn, bytes)
|
||||
}
|
||||
|
||||
func (u *udpHopPacketConn) SyscallConn() (syscall.RawConn, error) {
|
||||
u.connMutex.RLock()
|
||||
defer u.connMutex.RUnlock()
|
||||
sc, ok := u.currentConn.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
return sc.SyscallConn()
|
||||
}
|
||||
|
||||
func trySetReadBuffer(pc net.PacketConn, bytes int) error {
|
||||
sc, ok := pc.(interface {
|
||||
SetReadBuffer(bytes int) error
|
||||
})
|
||||
if ok {
|
||||
return sc.SetReadBuffer(bytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func trySetWriteBuffer(pc net.PacketConn, bytes int) error {
|
||||
sc, ok := pc.(interface {
|
||||
SetWriteBuffer(bytes int) error
|
||||
})
|
||||
if ok {
|
||||
return sc.SetWriteBuffer(bytes)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Reference in a new issue