mirror of
https://github.com/MetaCubeX/Clash.Meta.git
synced 2025-04-04 05:33:35 +03:00
feat: implement anytls client and server (#1844)
This commit is contained in:
parent
ef29e4501e
commit
9962a0d091
21 changed files with 2291 additions and 0 deletions
137
adapter/outbound/anytls.go
Normal file
137
adapter/outbound/anytls.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package outbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
CN "github.com/metacubex/mihomo/common/net"
|
||||
"github.com/metacubex/mihomo/component/dialer"
|
||||
"github.com/metacubex/mihomo/component/proxydialer"
|
||||
"github.com/metacubex/mihomo/component/resolver"
|
||||
tlsC "github.com/metacubex/mihomo/component/tls"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/transport/anytls"
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/uot"
|
||||
)
|
||||
|
||||
type AnyTLS struct {
|
||||
*Base
|
||||
client *anytls.Client
|
||||
dialer proxydialer.SingDialer
|
||||
option *AnyTLSOption
|
||||
}
|
||||
|
||||
type AnyTLSOption struct {
|
||||
BasicOption
|
||||
Name string `proxy:"name"`
|
||||
Server string `proxy:"server"`
|
||||
Port int `proxy:"port"`
|
||||
Password string `proxy:"password"`
|
||||
ALPN []string `proxy:"alpn,omitempty"`
|
||||
SNI string `proxy:"sni,omitempty"`
|
||||
ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
|
||||
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
|
||||
UDP bool `proxy:"udp,omitempty"`
|
||||
IdleSessionCheckInterval int `proxy:"idle-session-check-interval,omitempty"`
|
||||
IdleSessionTimeout int `proxy:"idle-session-timeout,omitempty"`
|
||||
}
|
||||
|
||||
func (t *AnyTLS) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
|
||||
options := t.Base.DialOptions(opts...)
|
||||
t.dialer.SetDialer(dialer.NewDialer(options...))
|
||||
c, err := t.client.CreateProxy(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(CN.NewRefConn(c, t), t), nil
|
||||
}
|
||||
|
||||
func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
|
||||
// create tcp
|
||||
options := t.Base.DialOptions(opts...)
|
||||
t.dialer.SetDialer(dialer.NewDialer(options...))
|
||||
c, err := t.client.CreateProxy(ctx, uot.RequestDestination(2))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create uot on tcp
|
||||
if !metadata.Resolved() {
|
||||
ip, err := resolver.ResolveIP(ctx, metadata.Host)
|
||||
if err != nil {
|
||||
return nil, errors.New("can't resolve ip")
|
||||
}
|
||||
metadata.DstIP = ip
|
||||
}
|
||||
destination := M.SocksaddrFromNet(metadata.UDPAddr())
|
||||
return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), t), nil
|
||||
}
|
||||
|
||||
// SupportUOT implements C.ProxyAdapter
|
||||
func (t *AnyTLS) SupportUOT() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ProxyInfo implements C.ProxyAdapter
|
||||
func (t *AnyTLS) ProxyInfo() C.ProxyInfo {
|
||||
info := t.Base.ProxyInfo()
|
||||
info.DialerProxy = t.option.DialerProxy
|
||||
return info
|
||||
}
|
||||
|
||||
func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) {
|
||||
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
|
||||
|
||||
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer())
|
||||
|
||||
tOption := anytls.ClientConfig{
|
||||
Password: option.Password,
|
||||
Server: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)),
|
||||
Dialer: singDialer,
|
||||
IdleSessionCheckInterval: time.Duration(option.IdleSessionCheckInterval) * time.Second,
|
||||
IdleSessionTimeout: time.Duration(option.IdleSessionTimeout) * time.Second,
|
||||
ClientFingerprint: option.ClientFingerprint,
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: option.SNI,
|
||||
InsecureSkipVerify: option.SkipCertVerify,
|
||||
NextProtos: option.ALPN,
|
||||
}
|
||||
if tlsConfig.ServerName == "" {
|
||||
tlsConfig.ServerName = "127.0.0.1"
|
||||
}
|
||||
tOption.TLSConfig = tlsConfig
|
||||
|
||||
if tlsC.HaveGlobalFingerprint() && len(tOption.ClientFingerprint) == 0 {
|
||||
tOption.ClientFingerprint = tlsC.GetGlobalFingerprint()
|
||||
}
|
||||
|
||||
outbound := &AnyTLS{
|
||||
Base: &Base{
|
||||
name: option.Name,
|
||||
addr: addr,
|
||||
tp: C.AnyTLS,
|
||||
udp: option.UDP,
|
||||
tfo: option.TFO,
|
||||
mpTcp: option.MPTCP,
|
||||
iface: option.Interface,
|
||||
rmark: option.RoutingMark,
|
||||
prefer: C.NewDNSPrefer(option.IPVersion),
|
||||
},
|
||||
client: anytls.NewClient(context.TODO(), tOption),
|
||||
option: &option,
|
||||
dialer: singDialer,
|
||||
}
|
||||
runtime.SetFinalizer(outbound, func(o *AnyTLS) {
|
||||
common.Close(o.client)
|
||||
})
|
||||
|
||||
return outbound, nil
|
||||
}
|
|
@ -148,6 +148,13 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
|
|||
break
|
||||
}
|
||||
proxy, err = outbound.NewMieru(*mieruOption)
|
||||
case "anytls":
|
||||
anytlsOption := &outbound.AnyTLSOption{}
|
||||
err = decoder.Decode(mapping, anytlsOption)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
proxy, err = outbound.NewAnyTLS(*anytlsOption)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ const (
|
|||
Tuic
|
||||
Ssh
|
||||
Mieru
|
||||
AnyTLS
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -229,6 +230,8 @@ func (at AdapterType) String() string {
|
|||
return "Ssh"
|
||||
case Mieru:
|
||||
return "Mieru"
|
||||
case AnyTLS:
|
||||
return "AnyTLS"
|
||||
case Relay:
|
||||
return "Relay"
|
||||
case Selector:
|
||||
|
|
|
@ -32,6 +32,7 @@ const (
|
|||
TUN
|
||||
TUIC
|
||||
HYSTERIA2
|
||||
ANYTLS
|
||||
INNER
|
||||
)
|
||||
|
||||
|
@ -84,6 +85,8 @@ func (t Type) String() string {
|
|||
return "Tuic"
|
||||
case HYSTERIA2:
|
||||
return "Hysteria2"
|
||||
case ANYTLS:
|
||||
return "AnyTLS"
|
||||
case INNER:
|
||||
return "Inner"
|
||||
default:
|
||||
|
@ -120,6 +123,8 @@ func ParseType(t string) (*Type, error) {
|
|||
res = TUIC
|
||||
case "HYSTERIA2":
|
||||
res = HYSTERIA2
|
||||
case "ANYTLS":
|
||||
res = ANYTLS
|
||||
case "INNER":
|
||||
res = INNER
|
||||
default:
|
||||
|
|
|
@ -864,6 +864,22 @@ proxies: # socks5
|
|||
# 可以使用的值包括 MULTIPLEXING_OFF, MULTIPLEXING_LOW, MULTIPLEXING_MIDDLE, MULTIPLEXING_HIGH。其中 MULTIPLEXING_OFF 会关闭多路复用功能。默认值为 MULTIPLEXING_LOW。
|
||||
# multiplexing: MULTIPLEXING_LOW
|
||||
|
||||
# anytls
|
||||
- name: anytls
|
||||
type: anytls
|
||||
server: 1.2.3.4
|
||||
port: 443
|
||||
password: "<your password>"
|
||||
# client-fingerprint: chrome
|
||||
udp: true
|
||||
# idle-session-check-interval: 30 # seconds
|
||||
# idle-session-timeout: 30 # seconds
|
||||
# sni: "example.com"
|
||||
# alpn:
|
||||
# - h2
|
||||
# - http/1.1
|
||||
# skip-cert-verify: true
|
||||
|
||||
# dns 出站会将请求劫持到内部 dns 模块,所有请求均在内部处理
|
||||
- name: "dns-out"
|
||||
type: dns
|
||||
|
@ -1209,6 +1225,18 @@ listeners:
|
|||
- test.com
|
||||
### 注意,对于vless listener, 至少需要填写 “certificate和private-key” 或 “reality-config” 的其中一项 ###
|
||||
|
||||
- name: anytls-in-1
|
||||
type: anytls
|
||||
port: 10818
|
||||
listen: 0.0.0.0
|
||||
users:
|
||||
username1: password1
|
||||
username2: password2
|
||||
# "certificate" and "private-key" are required
|
||||
certificate: ./server.crt
|
||||
private-key: ./server.key
|
||||
# padding-scheme: "" # https://github.com/anytls/anytls-go/blob/main/docs/protocol.md#cmdupdatepaddingscheme
|
||||
|
||||
- name: tun-in-1
|
||||
type: tun
|
||||
# rule: sub-rule-name1 # 默认使用 rules,如果未找到 sub-rule 则直接使用 rules
|
||||
|
|
181
listener/anytls/server.go
Normal file
181
listener/anytls/server.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package anytls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/metacubex/mihomo/adapter/inbound"
|
||||
"github.com/metacubex/mihomo/common/buf"
|
||||
N "github.com/metacubex/mihomo/common/net"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
LC "github.com/metacubex/mihomo/listener/config"
|
||||
"github.com/metacubex/mihomo/listener/sing"
|
||||
"github.com/metacubex/mihomo/transport/anytls/padding"
|
||||
"github.com/metacubex/mihomo/transport/anytls/session"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
closed bool
|
||||
config LC.AnyTLSServer
|
||||
listeners []net.Listener
|
||||
tlsConfig *tls.Config
|
||||
userMap map[[32]byte]string
|
||||
padding atomic.TypedValue[*padding.PaddingFactory]
|
||||
}
|
||||
|
||||
func New(config LC.AnyTLSServer, tunnel C.Tunnel, additions ...inbound.Addition) (sl *Listener, err error) {
|
||||
if len(additions) == 0 {
|
||||
additions = []inbound.Addition{
|
||||
inbound.WithInName("DEFAULT-ANYTLS"),
|
||||
inbound.WithSpecialRules(""),
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
if config.Certificate != "" && config.PrivateKey != "" {
|
||||
cert, err := N.ParseCert(config.Certificate, config.PrivateKey, C.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
sl = &Listener{
|
||||
config: config,
|
||||
tlsConfig: tlsConfig,
|
||||
userMap: make(map[[32]byte]string),
|
||||
}
|
||||
|
||||
for user, password := range config.Users {
|
||||
sl.userMap[sha256.Sum256([]byte(password))] = user
|
||||
}
|
||||
|
||||
if len(config.PaddingScheme) > 0 {
|
||||
if !padding.UpdatePaddingScheme([]byte(config.PaddingScheme), &sl.padding) {
|
||||
return nil, errors.New("incorrect padding scheme format")
|
||||
}
|
||||
} else {
|
||||
padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &sl.padding)
|
||||
}
|
||||
|
||||
// Using sing handler can automatically handle UoT
|
||||
h, err := sing.NewListenerHandler(sing.ListenerConfig{
|
||||
Tunnel: tunnel,
|
||||
Type: C.ANYTLS,
|
||||
Additions: additions,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, addr := range strings.Split(config.Listen, ",") {
|
||||
addr := addr
|
||||
|
||||
//TCP
|
||||
l, err := inbound.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sl.listeners = append(sl.listeners, l)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
if sl.closed {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
go sl.HandleConn(c, h)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return sl, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
l.closed = true
|
||||
var retErr error
|
||||
for _, lis := range l.listeners {
|
||||
err := lis.Close()
|
||||
if err != nil {
|
||||
retErr = err
|
||||
}
|
||||
}
|
||||
return retErr
|
||||
}
|
||||
|
||||
func (l *Listener) Config() string {
|
||||
return l.config.String()
|
||||
}
|
||||
|
||||
func (l *Listener) AddrList() (addrList []net.Addr) {
|
||||
for _, lis := range l.listeners {
|
||||
addrList = append(addrList, lis.Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) HandleConn(conn net.Conn, h *sing.ListenerHandler) {
|
||||
ctx := context.TODO()
|
||||
|
||||
conn = tls.Server(conn, l.tlsConfig)
|
||||
defer conn.Close()
|
||||
|
||||
b := buf.NewPacket()
|
||||
_, err := b.ReadOnceFrom(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn = bufio.NewCachedConn(conn, b)
|
||||
|
||||
by, err := b.ReadBytes(32)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var passwordSha256 [32]byte
|
||||
copy(passwordSha256[:], by)
|
||||
if user, ok := l.userMap[passwordSha256]; ok {
|
||||
ctx = auth.ContextWithUser(ctx, user)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
by, err = b.ReadBytes(2)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
paddingLen := binary.BigEndian.Uint16(by)
|
||||
if paddingLen > 0 {
|
||||
_, err = b.ReadBytes(int(paddingLen))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session := session.NewServerSession(conn, func(stream *session.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(stream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.NewConnection(ctx, stream, M.Metadata{
|
||||
Source: M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||
Destination: destination,
|
||||
})
|
||||
}, &l.padding)
|
||||
session.Run(true)
|
||||
session.Close()
|
||||
}
|
19
listener/config/anytls.go
Normal file
19
listener/config/anytls.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type AnyTLSServer struct {
|
||||
Enable bool `yaml:"enable" json:"enable"`
|
||||
Listen string `yaml:"listen" json:"listen"`
|
||||
Users map[string]string `yaml:"users" json:"users,omitempty"`
|
||||
Certificate string `yaml:"certificate" json:"certificate"`
|
||||
PrivateKey string `yaml:"private-key" json:"private-key"`
|
||||
PaddingScheme string `yaml:"padding-scheme" json:"padding-scheme,omitempty"`
|
||||
}
|
||||
|
||||
func (t AnyTLSServer) String() string {
|
||||
b, _ := json.Marshal(t)
|
||||
return string(b)
|
||||
}
|
79
listener/inbound/anytls.go
Normal file
79
listener/inbound/anytls.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package inbound
|
||||
|
||||
import (
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/listener/anytls"
|
||||
LC "github.com/metacubex/mihomo/listener/config"
|
||||
"github.com/metacubex/mihomo/log"
|
||||
)
|
||||
|
||||
type AnyTLSOption struct {
|
||||
BaseOption
|
||||
Users map[string]string `inbound:"users,omitempty"`
|
||||
Certificate string `inbound:"certificate"`
|
||||
PrivateKey string `inbound:"private-key"`
|
||||
PaddingScheme string `inbound:"padding-scheme,omitempty"`
|
||||
}
|
||||
|
||||
func (o AnyTLSOption) Equal(config C.InboundConfig) bool {
|
||||
return optionToString(o) == optionToString(config)
|
||||
}
|
||||
|
||||
type AnyTLS struct {
|
||||
*Base
|
||||
config *AnyTLSOption
|
||||
l C.MultiAddrListener
|
||||
vs LC.AnyTLSServer
|
||||
}
|
||||
|
||||
func NewAnyTLS(options *AnyTLSOption) (*AnyTLS, error) {
|
||||
base, err := NewBase(&options.BaseOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AnyTLS{
|
||||
Base: base,
|
||||
config: options,
|
||||
vs: LC.AnyTLSServer{
|
||||
Enable: true,
|
||||
Listen: base.RawAddress(),
|
||||
Users: options.Users,
|
||||
Certificate: options.Certificate,
|
||||
PrivateKey: options.PrivateKey,
|
||||
PaddingScheme: options.PaddingScheme,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Config implements constant.InboundListener
|
||||
func (v *AnyTLS) Config() C.InboundConfig {
|
||||
return v.config
|
||||
}
|
||||
|
||||
// Address implements constant.InboundListener
|
||||
func (v *AnyTLS) Address() string {
|
||||
if v.l != nil {
|
||||
for _, addr := range v.l.AddrList() {
|
||||
return addr.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Listen implements constant.InboundListener
|
||||
func (v *AnyTLS) Listen(tunnel C.Tunnel) error {
|
||||
var err error
|
||||
v.l, err = anytls.New(v.vs, tunnel, v.Additions()...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infoln("AnyTLS[%s] proxy listening at: %s", v.Name(), v.Address())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements constant.InboundListener
|
||||
func (v *AnyTLS) Close() error {
|
||||
return v.l.Close()
|
||||
}
|
||||
|
||||
var _ C.InboundListener = (*AnyTLS)(nil)
|
|
@ -113,6 +113,13 @@ func ParseListener(mapping map[string]any) (C.InboundListener, error) {
|
|||
return nil, err
|
||||
}
|
||||
listener, err = IN.NewTuic(tuicOption)
|
||||
case "anytls":
|
||||
anytlsOption := &IN.AnyTLSOption{}
|
||||
err = decoder.Decode(mapping, anytlsOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listener, err = IN.NewAnyTLS(anytlsOption)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
|
||||
}
|
||||
|
|
123
transport/anytls/client.go
Normal file
123
transport/anytls/client.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package anytls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
tlsC "github.com/metacubex/mihomo/component/tls"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/transport/anytls/padding"
|
||||
"github.com/metacubex/mihomo/transport/anytls/session"
|
||||
"github.com/metacubex/mihomo/transport/vmess"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Password string
|
||||
IdleSessionCheckInterval time.Duration
|
||||
IdleSessionTimeout time.Duration
|
||||
Server M.Socksaddr
|
||||
Dialer N.Dialer
|
||||
TLSConfig *tls.Config
|
||||
ClientFingerprint string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
passwordSha256 []byte
|
||||
tlsConfig *tls.Config
|
||||
clientFingerprint string
|
||||
dialer N.Dialer
|
||||
server M.Socksaddr
|
||||
sessionClient *session.Client
|
||||
padding atomic.TypedValue[*padding.PaddingFactory]
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, config ClientConfig) *Client {
|
||||
pw := sha256.Sum256([]byte(config.Password))
|
||||
c := &Client{
|
||||
passwordSha256: pw[:],
|
||||
tlsConfig: config.TLSConfig,
|
||||
clientFingerprint: config.ClientFingerprint,
|
||||
dialer: config.Dialer,
|
||||
server: config.Server,
|
||||
}
|
||||
// Initialize the padding state of this client
|
||||
padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding)
|
||||
c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||
conn, err := c.sessionClient.CreateStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := buf.NewPacket()
|
||||
b.Write(c.passwordSha256)
|
||||
var paddingLen int
|
||||
if pad := c.padding.Load().GenerateRecordPayloadSizes(0); len(pad) > 0 {
|
||||
paddingLen = pad[0]
|
||||
}
|
||||
binary.BigEndian.PutUint16(b.Extend(2), uint16(paddingLen))
|
||||
if paddingLen > 0 {
|
||||
b.WriteZeroN(paddingLen)
|
||||
}
|
||||
|
||||
getTlsConn := func() (net.Conn, error) {
|
||||
if len(c.clientFingerprint) != 0 {
|
||||
utlsConn, valid := vmess.GetUTLSConn(conn, c.clientFingerprint, c.tlsConfig)
|
||||
if valid {
|
||||
ctx, cancel := context.WithTimeout(ctx, C.DefaultTLSTimeout)
|
||||
defer cancel()
|
||||
|
||||
err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx)
|
||||
return utlsConn, err
|
||||
}
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(conn, c.tlsConfig)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
|
||||
defer cancel()
|
||||
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
return tlsConn, err
|
||||
}
|
||||
tlsConn, err := getTlsConn()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = b.WriteTo(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (h *Client) Close() error {
|
||||
return h.sessionClient.Close()
|
||||
}
|
92
transport/anytls/padding/padding.go
Normal file
92
transport/anytls/padding/padding.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package padding
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/anytls/util"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
)
|
||||
|
||||
const CheckMark = -1
|
||||
|
||||
var DefaultPaddingScheme = []byte(`stop=8
|
||||
0=34-120
|
||||
1=100-400
|
||||
2=400-500,c,500-1000,c,400-500,c,500-1000,c,500-1000,c,400-500
|
||||
3=500-1000
|
||||
4=500-1000
|
||||
5=500-1000
|
||||
6=500-1000
|
||||
7=500-1000`)
|
||||
|
||||
type PaddingFactory struct {
|
||||
scheme util.StringMap
|
||||
RawScheme []byte
|
||||
Stop uint32
|
||||
Md5 string
|
||||
}
|
||||
|
||||
func UpdatePaddingScheme(rawScheme []byte, to *atomic.TypedValue[*PaddingFactory]) bool {
|
||||
if p := NewPaddingFactory(rawScheme); p != nil {
|
||||
to.Store(p)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func NewPaddingFactory(rawScheme []byte) *PaddingFactory {
|
||||
p := &PaddingFactory{
|
||||
RawScheme: rawScheme,
|
||||
Md5: fmt.Sprintf("%x", md5.Sum(rawScheme)),
|
||||
}
|
||||
scheme := util.StringMapFromBytes(rawScheme)
|
||||
if len(scheme) == 0 {
|
||||
return nil
|
||||
}
|
||||
if stop, err := strconv.Atoi(scheme["stop"]); err == nil {
|
||||
p.Stop = uint32(stop)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
p.scheme = scheme
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) (pktSizes []int) {
|
||||
if s, ok := p.scheme[strconv.Itoa(int(pkt))]; ok {
|
||||
sRanges := strings.Split(s, ",")
|
||||
for _, sRange := range sRanges {
|
||||
sRangeMinMax := strings.Split(sRange, "-")
|
||||
if len(sRangeMinMax) == 2 {
|
||||
_min, err := strconv.ParseInt(sRangeMinMax[0], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_max, err := strconv.ParseInt(sRangeMinMax[1], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_min, _max = common.Min(_min, _max), common.Max(_min, _max)
|
||||
if _min <= 0 || _max <= 0 {
|
||||
continue
|
||||
}
|
||||
if _min == _max {
|
||||
pktSizes = append(pktSizes, int(_min))
|
||||
} else {
|
||||
i, _ := rand.Int(rand.Reader, big.NewInt(_max-_min))
|
||||
pktSizes = append(pktSizes, int(i.Int64()+_min))
|
||||
}
|
||||
} else if sRange == "c" {
|
||||
pktSizes = append(pktSizes, CheckMark)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
160
transport/anytls/session/client.go
Normal file
160
transport/anytls/session/client.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/anytls/padding"
|
||||
"github.com/metacubex/mihomo/transport/anytls/skiplist"
|
||||
"github.com/metacubex/mihomo/transport/anytls/util"
|
||||
"github.com/sagernet/sing/common"
|
||||
singAtomic "github.com/sagernet/sing/common/atomic"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
die context.Context
|
||||
dieCancel context.CancelFunc
|
||||
|
||||
dialOut func(ctx context.Context) (net.Conn, error)
|
||||
|
||||
sessionCounter atomic.Uint64
|
||||
idleSession *skiplist.SkipList[uint64, *Session]
|
||||
idleSessionLock sync.Mutex
|
||||
|
||||
padding *singAtomic.TypedValue[*padding.PaddingFactory]
|
||||
|
||||
idleSessionTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client {
|
||||
c := &Client{
|
||||
dialOut: dialOut,
|
||||
padding: _padding,
|
||||
idleSessionTimeout: idleSessionTimeout,
|
||||
}
|
||||
if idleSessionCheckInterval <= time.Second*5 {
|
||||
idleSessionCheckInterval = time.Second * 30
|
||||
}
|
||||
if c.idleSessionTimeout <= time.Second*5 {
|
||||
c.idleSessionTimeout = time.Second * 30
|
||||
}
|
||||
c.die, c.dieCancel = context.WithCancel(ctx)
|
||||
c.idleSession = skiplist.NewSkipList[uint64, *Session]()
|
||||
util.StartRoutine(c.die, idleSessionCheckInterval, c.idleCleanup)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) {
|
||||
select {
|
||||
case <-c.die.Done():
|
||||
return nil, io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
var session *Session
|
||||
var stream *Stream
|
||||
var err error
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
session, err = c.findSession(ctx)
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
stream, err = session.OpenStream()
|
||||
if err != nil {
|
||||
common.Close(session, stream)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if session == nil || stream == nil {
|
||||
return nil, fmt.Errorf("too many closed session: %w", err)
|
||||
}
|
||||
|
||||
stream.dieHook = func() {
|
||||
if session.IsClosed() {
|
||||
if session.dieHook != nil {
|
||||
session.dieHook()
|
||||
}
|
||||
} else {
|
||||
c.idleSessionLock.Lock()
|
||||
session.idleSince = time.Now()
|
||||
c.idleSession.Insert(math.MaxUint64-session.seq, session)
|
||||
c.idleSessionLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *Client) findSession(ctx context.Context) (*Session, error) {
|
||||
var idle *Session
|
||||
|
||||
c.idleSessionLock.Lock()
|
||||
if !c.idleSession.IsEmpty() {
|
||||
it := c.idleSession.Iterate()
|
||||
idle = it.Value()
|
||||
c.idleSession.Remove(it.Key())
|
||||
}
|
||||
c.idleSessionLock.Unlock()
|
||||
|
||||
if idle == nil {
|
||||
s, err := c.createSession(ctx)
|
||||
return s, err
|
||||
}
|
||||
return idle, nil
|
||||
}
|
||||
|
||||
func (c *Client) createSession(ctx context.Context) (*Session, error) {
|
||||
underlying, err := c.dialOut(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session := NewClientSession(underlying, c.padding)
|
||||
session.seq = c.sessionCounter.Add(1)
|
||||
session.dieHook = func() {
|
||||
//logrus.Debugln("session died", session)
|
||||
c.idleSessionLock.Lock()
|
||||
c.idleSession.Remove(math.MaxUint64 - session.seq)
|
||||
c.idleSessionLock.Unlock()
|
||||
}
|
||||
session.Run(false)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.dieCancel()
|
||||
go c.idleCleanupExpTime(time.Now())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) idleCleanup() {
|
||||
c.idleCleanupExpTime(time.Now().Add(-c.idleSessionTimeout))
|
||||
}
|
||||
|
||||
func (c *Client) idleCleanupExpTime(expTime time.Time) {
|
||||
var sessionToRemove = make([]*Session, 0)
|
||||
|
||||
c.idleSessionLock.Lock()
|
||||
it := c.idleSession.Iterate()
|
||||
for it.IsNotEnd() {
|
||||
session := it.Value()
|
||||
if session.idleSince.Before(expTime) {
|
||||
sessionToRemove = append(sessionToRemove, session)
|
||||
c.idleSession.Remove(it.Key())
|
||||
}
|
||||
it.MoveToNext()
|
||||
}
|
||||
c.idleSessionLock.Unlock()
|
||||
|
||||
for _, session := range sessionToRemove {
|
||||
session.Close()
|
||||
}
|
||||
}
|
44
transport/anytls/session/frame.go
Normal file
44
transport/anytls/session/frame.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
const ( // cmds
|
||||
cmdWaste = 0 // Paddings
|
||||
cmdSYN = 1 // stream open
|
||||
cmdPSH = 2 // data push
|
||||
cmdFIN = 3 // stream close, a.k.a EOF mark
|
||||
cmdSettings = 4 // Settings
|
||||
cmdAlert = 5 // Alert
|
||||
cmdUpdatePaddingScheme = 6 // update padding scheme
|
||||
)
|
||||
|
||||
const (
|
||||
headerOverHeadSize = 1 + 4 + 2
|
||||
)
|
||||
|
||||
// frame defines a packet from or to be multiplexed into a single connection
|
||||
type frame struct {
|
||||
cmd byte // 1
|
||||
sid uint32 // 4
|
||||
data []byte // 2 + len(data)
|
||||
}
|
||||
|
||||
func newFrame(cmd byte, sid uint32) frame {
|
||||
return frame{cmd: cmd, sid: sid}
|
||||
}
|
||||
|
||||
type rawHeader [headerOverHeadSize]byte
|
||||
|
||||
func (h rawHeader) Cmd() byte {
|
||||
return h[0]
|
||||
}
|
||||
|
||||
func (h rawHeader) StreamID() uint32 {
|
||||
return binary.BigEndian.Uint32(h[1:])
|
||||
}
|
||||
|
||||
func (h rawHeader) Length() uint16 {
|
||||
return binary.BigEndian.Uint16(h[5:])
|
||||
}
|
379
transport/anytls/session/session.go
Normal file
379
transport/anytls/session/session.go
Normal file
|
@ -0,0 +1,379 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/log"
|
||||
"github.com/metacubex/mihomo/transport/anytls/padding"
|
||||
"github.com/metacubex/mihomo/transport/anytls/util"
|
||||
singAtomic "github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
conn net.Conn
|
||||
connLock sync.Mutex
|
||||
|
||||
streams map[uint32]*Stream
|
||||
streamId atomic.Uint32
|
||||
streamLock sync.RWMutex
|
||||
|
||||
dieOnce sync.Once
|
||||
die chan struct{}
|
||||
dieHook func()
|
||||
|
||||
// pool
|
||||
seq uint64
|
||||
idleSince time.Time
|
||||
padding *singAtomic.TypedValue[*padding.PaddingFactory]
|
||||
|
||||
// client
|
||||
isClient bool
|
||||
buffering bool
|
||||
buffer []byte
|
||||
pktCounter atomic.Uint32
|
||||
|
||||
// server
|
||||
onNewStream func(stream *Stream)
|
||||
}
|
||||
|
||||
func NewClientSession(conn net.Conn, _padding *singAtomic.TypedValue[*padding.PaddingFactory]) *Session {
|
||||
s := &Session{
|
||||
conn: conn,
|
||||
isClient: true,
|
||||
padding: _padding,
|
||||
}
|
||||
s.die = make(chan struct{})
|
||||
s.streams = make(map[uint32]*Stream)
|
||||
return s
|
||||
}
|
||||
|
||||
func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *singAtomic.TypedValue[*padding.PaddingFactory]) *Session {
|
||||
s := &Session{
|
||||
conn: conn,
|
||||
onNewStream: onNewStream,
|
||||
isClient: false,
|
||||
padding: _padding,
|
||||
}
|
||||
s.die = make(chan struct{})
|
||||
s.streams = make(map[uint32]*Stream)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Session) Run(isServer bool) {
|
||||
if isServer {
|
||||
s.recvLoop()
|
||||
return
|
||||
}
|
||||
|
||||
settings := util.StringMap{
|
||||
"v": "1",
|
||||
"client": "mihomo/" + constant.Version,
|
||||
"padding-md5": s.padding.Load().Md5,
|
||||
}
|
||||
f := newFrame(cmdSettings, 0)
|
||||
f.data = settings.ToBytes()
|
||||
s.buffering = true
|
||||
s.writeFrame(f)
|
||||
|
||||
go s.recvLoop()
|
||||
}
|
||||
|
||||
// IsClosed does a safe check to see if we have shutdown
|
||||
func (s *Session) IsClosed() bool {
|
||||
select {
|
||||
case <-s.die:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close is used to close the session and all streams.
|
||||
func (s *Session) Close() error {
|
||||
var once bool
|
||||
s.dieOnce.Do(func() {
|
||||
close(s.die)
|
||||
once = true
|
||||
})
|
||||
|
||||
if once {
|
||||
if s.dieHook != nil {
|
||||
s.dieHook()
|
||||
}
|
||||
s.streamLock.Lock()
|
||||
for k := range s.streams {
|
||||
s.streams[k].sessionClose()
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
return s.conn.Close()
|
||||
} else {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
// OpenStream is used to create a new stream for CLIENT
|
||||
func (s *Session) OpenStream() (*Stream, error) {
|
||||
if s.IsClosed() {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
sid := s.streamId.Add(1)
|
||||
stream := newStream(sid, s)
|
||||
|
||||
//logrus.Debugln("stream open", sid, s.streams)
|
||||
|
||||
if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.buffering = false // proxy Write it's SocksAddr to flush the buffer
|
||||
|
||||
s.streamLock.Lock()
|
||||
defer s.streamLock.Unlock()
|
||||
select {
|
||||
case <-s.die:
|
||||
return nil, io.ErrClosedPipe
|
||||
default:
|
||||
s.streams[sid] = stream
|
||||
return stream, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) recvLoop() error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorln("[BUG] %v %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
defer s.Close()
|
||||
|
||||
var receivedSettingsFromClient bool
|
||||
var hdr rawHeader
|
||||
|
||||
for {
|
||||
if s.IsClosed() {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
// read header first
|
||||
if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
|
||||
sid := hdr.StreamID()
|
||||
switch hdr.Cmd() {
|
||||
case cmdPSH:
|
||||
if hdr.Length() > 0 {
|
||||
buffer := buf.Get(int(hdr.Length()))
|
||||
if _, err := io.ReadFull(s.conn, buffer); err == nil {
|
||||
s.streamLock.RLock()
|
||||
stream, ok := s.streams[sid]
|
||||
s.streamLock.RUnlock()
|
||||
if ok {
|
||||
stream.pipeW.Write(buffer)
|
||||
}
|
||||
buf.Put(buffer)
|
||||
} else {
|
||||
buf.Put(buffer)
|
||||
return err
|
||||
}
|
||||
}
|
||||
case cmdSYN: // should be server only
|
||||
if !s.isClient && !receivedSettingsFromClient {
|
||||
f := newFrame(cmdAlert, 0)
|
||||
f.data = []byte("client did not send its settings")
|
||||
s.writeFrame(f)
|
||||
return nil
|
||||
}
|
||||
s.streamLock.Lock()
|
||||
if _, ok := s.streams[sid]; !ok {
|
||||
stream := newStream(sid, s)
|
||||
s.streams[sid] = stream
|
||||
if s.onNewStream != nil {
|
||||
go s.onNewStream(stream)
|
||||
} else {
|
||||
go s.Close()
|
||||
}
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
case cmdFIN:
|
||||
s.streamLock.RLock()
|
||||
stream, ok := s.streams[sid]
|
||||
s.streamLock.RUnlock()
|
||||
if ok {
|
||||
stream.Close()
|
||||
}
|
||||
//logrus.Debugln("stream fin", sid, s.streams)
|
||||
case cmdWaste:
|
||||
if hdr.Length() > 0 {
|
||||
buffer := buf.Get(int(hdr.Length()))
|
||||
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||
buf.Put(buffer)
|
||||
return err
|
||||
}
|
||||
buf.Put(buffer)
|
||||
}
|
||||
case cmdSettings:
|
||||
if hdr.Length() > 0 {
|
||||
buffer := buf.Get(int(hdr.Length()))
|
||||
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||
buf.Put(buffer)
|
||||
return err
|
||||
}
|
||||
if !s.isClient {
|
||||
receivedSettingsFromClient = true
|
||||
m := util.StringMapFromBytes(buffer)
|
||||
paddingF := s.padding.Load()
|
||||
if m["padding-md5"] != paddingF.Md5 {
|
||||
// logrus.Debugln("remote md5 is", m["padding-md5"])
|
||||
f := newFrame(cmdUpdatePaddingScheme, 0)
|
||||
f.data = paddingF.RawScheme
|
||||
_, err = s.writeFrame(f)
|
||||
if err != nil {
|
||||
buf.Put(buffer)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
buf.Put(buffer)
|
||||
}
|
||||
case cmdAlert:
|
||||
if hdr.Length() > 0 {
|
||||
buffer := buf.Get(int(hdr.Length()))
|
||||
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||
buf.Put(buffer)
|
||||
return err
|
||||
}
|
||||
if s.isClient {
|
||||
log.Errorln("[Alert from server] %s", string(buffer))
|
||||
}
|
||||
buf.Put(buffer)
|
||||
return nil
|
||||
}
|
||||
case cmdUpdatePaddingScheme:
|
||||
if hdr.Length() > 0 {
|
||||
buffer := buf.Get(int(hdr.Length()))
|
||||
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||
buf.Put(buffer)
|
||||
return err
|
||||
}
|
||||
if s.isClient {
|
||||
if padding.UpdatePaddingScheme(buffer, s.padding) {
|
||||
log.Infoln("[Update padding succeed] %x\n", md5.Sum(buffer))
|
||||
} else {
|
||||
log.Warnln("[Update padding failed] %x\n", md5.Sum(buffer))
|
||||
}
|
||||
}
|
||||
buf.Put(buffer)
|
||||
}
|
||||
default:
|
||||
// I don't know what command it is (can't have data)
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// notify the session that a stream has closed
|
||||
func (s *Session) streamClosed(sid uint32) error {
|
||||
_, err := s.writeFrame(newFrame(cmdFIN, sid))
|
||||
s.streamLock.Lock()
|
||||
delete(s.streams, sid)
|
||||
s.streamLock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Session) writeFrame(frame frame) (int, error) {
|
||||
dataLen := len(frame.data)
|
||||
|
||||
buffer := buf.NewSize(dataLen + headerOverHeadSize)
|
||||
buffer.WriteByte(frame.cmd)
|
||||
binary.BigEndian.PutUint32(buffer.Extend(4), frame.sid)
|
||||
binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen))
|
||||
buffer.Write(frame.data)
|
||||
_, err := s.writeConn(buffer.Bytes())
|
||||
buffer.Release()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return dataLen, nil
|
||||
}
|
||||
|
||||
func (s *Session) writeConn(b []byte) (n int, err error) {
|
||||
s.connLock.Lock()
|
||||
defer s.connLock.Unlock()
|
||||
|
||||
if s.buffering {
|
||||
s.buffer = append(s.buffer, b...)
|
||||
return len(b), nil
|
||||
} else if len(s.buffer) > 0 {
|
||||
b = append(s.buffer, b...)
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// calulate & send padding
|
||||
if s.isClient {
|
||||
pkt := s.pktCounter.Add(1)
|
||||
paddingF := s.padding.Load()
|
||||
if pkt < paddingF.Stop {
|
||||
pktSizes := paddingF.GenerateRecordPayloadSizes(pkt)
|
||||
for _, l := range pktSizes {
|
||||
remainPayloadLen := len(b)
|
||||
if l == padding.CheckMark {
|
||||
if remainPayloadLen == 0 {
|
||||
break
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// logrus.Debugln(pkt, "write", l, "len", remainPayloadLen, "remain", remainPayloadLen-l)
|
||||
if remainPayloadLen > l { // this packet is all payload
|
||||
_, err = s.conn.Write(b[:l])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n += l
|
||||
b = b[l:]
|
||||
} else if remainPayloadLen > 0 { // this packet contains padding and the last part of payload
|
||||
paddingLen := l - remainPayloadLen
|
||||
if paddingLen > 0 {
|
||||
padding := make([]byte, headerOverHeadSize+paddingLen)
|
||||
padding[0] = cmdWaste
|
||||
binary.BigEndian.PutUint32(padding[1:5], 0)
|
||||
binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen))
|
||||
b = append(b, padding...)
|
||||
}
|
||||
_, err = s.conn.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n += remainPayloadLen
|
||||
b = nil
|
||||
} else { // this packet is all padding
|
||||
padding := make([]byte, headerOverHeadSize+l)
|
||||
padding[0] = cmdWaste
|
||||
binary.BigEndian.PutUint32(padding[1:5], 0)
|
||||
binary.BigEndian.PutUint16(padding[5:7], uint16(l))
|
||||
_, err = s.conn.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b = nil
|
||||
}
|
||||
}
|
||||
// maybe still remain payload to write
|
||||
if len(b) == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.conn.Write(b)
|
||||
}
|
99
transport/anytls/session/stream.go
Normal file
99
transport/anytls/session/stream.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Stream implements net.Conn
|
||||
type Stream struct {
|
||||
id uint32
|
||||
|
||||
sess *Session
|
||||
|
||||
pipeR *io.PipeReader
|
||||
pipeW *io.PipeWriter
|
||||
|
||||
dieOnce sync.Once
|
||||
dieHook func()
|
||||
}
|
||||
|
||||
// newStream initiates a Stream struct
|
||||
func newStream(id uint32, sess *Session) *Stream {
|
||||
s := new(Stream)
|
||||
s.id = id
|
||||
s.sess = sess
|
||||
s.pipeR, s.pipeW = io.Pipe()
|
||||
return s
|
||||
}
|
||||
|
||||
// Read implements net.Conn
|
||||
func (s *Stream) Read(b []byte) (n int, err error) {
|
||||
return s.pipeR.Read(b)
|
||||
}
|
||||
|
||||
// Write implements net.Conn
|
||||
func (s *Stream) Write(b []byte) (n int, err error) {
|
||||
f := newFrame(cmdPSH, s.id)
|
||||
f.data = b
|
||||
n, err = s.sess.writeFrame(f)
|
||||
return
|
||||
}
|
||||
|
||||
// Close implements net.Conn
|
||||
func (s *Stream) Close() error {
|
||||
if s.sessionClose() {
|
||||
// notify remote
|
||||
return s.sess.streamClosed(s.id)
|
||||
} else {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
// sessionClose close stream from session side, do not notify remote
|
||||
func (s *Stream) sessionClose() (once bool) {
|
||||
s.dieOnce.Do(func() {
|
||||
s.pipeR.Close()
|
||||
once = true
|
||||
if s.dieHook != nil {
|
||||
s.dieHook()
|
||||
s.dieHook = nil
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Stream) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
|
||||
func (s *Stream) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
|
||||
func (s *Stream) SetDeadline(t time.Time) error {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
|
||||
// LocalAddr satisfies net.Conn interface
|
||||
func (s *Stream) LocalAddr() net.Addr {
|
||||
if ts, ok := s.sess.conn.(interface {
|
||||
LocalAddr() net.Addr
|
||||
}); ok {
|
||||
return ts.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr satisfies net.Conn interface
|
||||
func (s *Stream) RemoteAddr() net.Addr {
|
||||
if ts, ok := s.sess.conn.(interface {
|
||||
RemoteAddr() net.Addr
|
||||
}); ok {
|
||||
return ts.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
46
transport/anytls/skiplist/contianer.go
Normal file
46
transport/anytls/skiplist/contianer.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package skiplist
|
||||
|
||||
// Container is a holder object that stores a collection of other objects.
|
||||
type Container interface {
|
||||
IsEmpty() bool // IsEmpty checks if the container has no elements.
|
||||
Len() int // Len returns the number of elements in the container.
|
||||
Clear() // Clear erases all elements from the container. After this call, Len() returns zero.
|
||||
}
|
||||
|
||||
// Map is a associative container that contains key-value pairs with unique keys.
|
||||
type Map[K any, V any] interface {
|
||||
Container
|
||||
Has(K) bool // Checks whether the container contains element with specific key.
|
||||
Find(K) *V // Finds element with specific key.
|
||||
Insert(K, V) // Inserts a key-value pair in to the container or replace existing value.
|
||||
Remove(K) bool // Remove element with specific key.
|
||||
ForEach(func(K, V)) // Iterate the container.
|
||||
ForEachIf(func(K, V) bool) // Iterate the container, stops when the callback returns false.
|
||||
ForEachMutable(func(K, *V)) // Iterate the container, *V is mutable.
|
||||
ForEachMutableIf(func(K, *V) bool) // Iterate the container, *V is mutable, stops when the callback returns false.
|
||||
}
|
||||
|
||||
// Set is a containers that store unique elements.
|
||||
type Set[K any] interface {
|
||||
Container
|
||||
Has(K) bool // Checks whether the container contains element with specific key.
|
||||
Insert(K) // Inserts a key-value pair in to the container or replace existing value.
|
||||
InsertN(...K) // Inserts multiple key-value pairs in to the container or replace existing value.
|
||||
Remove(K) bool // Remove element with specific key.
|
||||
RemoveN(...K) // Remove multiple elements with specific keys.
|
||||
ForEach(func(K)) // Iterate the container.
|
||||
ForEachIf(func(K) bool) // Iterate the container, stops when the callback returns false.
|
||||
}
|
||||
|
||||
// Iterator is the interface for container's iterator.
|
||||
type Iterator[T any] interface {
|
||||
IsNotEnd() bool // Whether it is point to the end of the range.
|
||||
MoveToNext() // Let it point to the next element.
|
||||
Value() T // Return the value of current element.
|
||||
}
|
||||
|
||||
// MapIterator is the interface for map's iterator.
|
||||
type MapIterator[K any, V any] interface {
|
||||
Iterator[V]
|
||||
Key() K // The key of the element
|
||||
}
|
455
transport/anytls/skiplist/skiplist.go
Normal file
455
transport/anytls/skiplist/skiplist.go
Normal file
|
@ -0,0 +1,455 @@
|
|||
package skiplist
|
||||
|
||||
// This implementation is based on https://github.com/liyue201/gostl/tree/master/ds/skiplist
|
||||
// (many thanks), added many optimizations, such as:
|
||||
//
|
||||
// - adaptive level
|
||||
// - lesser search for prevs when key already exists.
|
||||
// - reduce memory allocations
|
||||
// - richer interface.
|
||||
//
|
||||
// etc.
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
const (
|
||||
skipListMaxLevel = 40
|
||||
)
|
||||
|
||||
// SkipList is a probabilistic data structure that seem likely to supplant balanced trees as the
|
||||
// implementation method of choice for many applications. Skip list algorithms have the same
|
||||
// asymptotic expected time bounds as balanced trees and are simpler, faster and use less space.
|
||||
//
|
||||
// See https://en.wikipedia.org/wiki/Skip_list for more details.
|
||||
type SkipList[K any, V any] struct {
|
||||
level int // Current level, may increase dynamically during insertion
|
||||
len int // Total elements numner in the skiplist.
|
||||
head skipListNode[K, V] // head.next[level] is the head of each level.
|
||||
// This cache is used to save the previous nodes when modifying the skip list to avoid
|
||||
// allocating memory each time it is called.
|
||||
prevsCache []*skipListNode[K, V]
|
||||
rander *rand.Rand
|
||||
impl skipListImpl[K, V]
|
||||
}
|
||||
|
||||
// NewSkipList creates a new SkipList for Ordered key type.
|
||||
func NewSkipList[K Ordered, V any]() *SkipList[K, V] {
|
||||
sl := skipListOrdered[K, V]{}
|
||||
sl.init()
|
||||
sl.impl = (skipListImpl[K, V])(&sl)
|
||||
return &sl.SkipList
|
||||
}
|
||||
|
||||
// NewSkipListFromMap creates a new SkipList from a map.
|
||||
func NewSkipListFromMap[K Ordered, V any](m map[K]V) *SkipList[K, V] {
|
||||
sl := NewSkipList[K, V]()
|
||||
for k, v := range m {
|
||||
sl.Insert(k, v)
|
||||
}
|
||||
return sl
|
||||
}
|
||||
|
||||
// NewSkipListFunc creates a new SkipList with specified compare function keyCmp.
|
||||
func NewSkipListFunc[K any, V any](keyCmp CompareFn[K]) *SkipList[K, V] {
|
||||
sl := skipListFunc[K, V]{}
|
||||
sl.init()
|
||||
sl.keyCmp = keyCmp
|
||||
sl.impl = skipListImpl[K, V](&sl)
|
||||
return &sl.SkipList
|
||||
}
|
||||
|
||||
// IsEmpty implements the Container interface.
|
||||
func (sl *SkipList[K, V]) IsEmpty() bool {
|
||||
return sl.len == 0
|
||||
}
|
||||
|
||||
// Len implements the Container interface.
|
||||
func (sl *SkipList[K, V]) Len() int {
|
||||
return sl.len
|
||||
}
|
||||
|
||||
// Clear implements the Container interface.
|
||||
func (sl *SkipList[K, V]) Clear() {
|
||||
for i := range sl.head.next {
|
||||
sl.head.next[i] = nil
|
||||
}
|
||||
sl.level = 1
|
||||
sl.len = 0
|
||||
}
|
||||
|
||||
// Iterate return an iterator to the skiplist.
|
||||
func (sl *SkipList[K, V]) Iterate() MapIterator[K, V] {
|
||||
return &skipListIterator[K, V]{sl.head.next[0], nil}
|
||||
}
|
||||
|
||||
// Insert inserts a key-value pair into the skiplist.
|
||||
// If the key is already in the skip list, it's value will be updated.
|
||||
func (sl *SkipList[K, V]) Insert(key K, value V) {
|
||||
node, prevs := sl.impl.findInsertPoint(key)
|
||||
|
||||
if node != nil {
|
||||
// Already exist, update the value
|
||||
node.value = value
|
||||
return
|
||||
}
|
||||
|
||||
level := sl.randomLevel()
|
||||
node = newSkipListNode(level, key, value)
|
||||
|
||||
for i := 0; i < common.Min(level, sl.level); i++ {
|
||||
node.next[i] = prevs[i].next[i]
|
||||
prevs[i].next[i] = node
|
||||
}
|
||||
|
||||
if level > sl.level {
|
||||
for i := sl.level; i < level; i++ {
|
||||
sl.head.next[i] = node
|
||||
}
|
||||
sl.level = level
|
||||
}
|
||||
|
||||
sl.len++
|
||||
}
|
||||
|
||||
// Find returns the value associated with the passed key if the key is in the skiplist, otherwise
|
||||
// returns nil.
|
||||
func (sl *SkipList[K, V]) Find(key K) *V {
|
||||
node := sl.impl.findNode(key)
|
||||
if node != nil {
|
||||
return &node.value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Has implement the Map interface.
|
||||
func (sl *SkipList[K, V]) Has(key K) bool {
|
||||
return sl.impl.findNode(key) != nil
|
||||
}
|
||||
|
||||
// LowerBound returns an iterator to the first element in the skiplist that
|
||||
// does not satisfy element < value (i.e. greater or equal to),
|
||||
// or a end itetator if no such element is found.
|
||||
func (sl *SkipList[K, V]) LowerBound(key K) MapIterator[K, V] {
|
||||
return &skipListIterator[K, V]{sl.impl.lowerBound(key), nil}
|
||||
}
|
||||
|
||||
// UpperBound returns an iterator to the first element in the skiplist that
|
||||
// does not satisfy value < element (i.e. strictly greater),
|
||||
// or a end itetator if no such element is found.
|
||||
func (sl *SkipList[K, V]) UpperBound(key K) MapIterator[K, V] {
|
||||
return &skipListIterator[K, V]{sl.impl.upperBound(key), nil}
|
||||
}
|
||||
|
||||
// FindRange returns an iterator in range [first, last) (last is not includeed).
|
||||
func (sl *SkipList[K, V]) FindRange(first, last K) MapIterator[K, V] {
|
||||
return &skipListIterator[K, V]{sl.impl.lowerBound(first), sl.impl.upperBound(last)}
|
||||
}
|
||||
|
||||
// Remove removes the key-value pair associated with the passed key and returns true if the key is
|
||||
// in the skiplist, otherwise returns false.
|
||||
func (sl *SkipList[K, V]) Remove(key K) bool {
|
||||
node, prevs := sl.impl.findRemovePoint(key)
|
||||
if node == nil {
|
||||
return false
|
||||
}
|
||||
for i, v := range node.next {
|
||||
prevs[i].next[i] = v
|
||||
}
|
||||
for sl.level > 1 && sl.head.next[sl.level-1] == nil {
|
||||
sl.level--
|
||||
}
|
||||
sl.len--
|
||||
return true
|
||||
}
|
||||
|
||||
// ForEach implements the Map interface.
|
||||
func (sl *SkipList[K, V]) ForEach(op func(K, V)) {
|
||||
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||
op(e.key, e.value)
|
||||
}
|
||||
}
|
||||
|
||||
// ForEachMutable implements the Map interface.
|
||||
func (sl *SkipList[K, V]) ForEachMutable(op func(K, *V)) {
|
||||
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||
op(e.key, &e.value)
|
||||
}
|
||||
}
|
||||
|
||||
// ForEachIf implements the Map interface.
|
||||
func (sl *SkipList[K, V]) ForEachIf(op func(K, V) bool) {
|
||||
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||
if !op(e.key, e.value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ForEachMutableIf implements the Map interface.
|
||||
func (sl *SkipList[K, V]) ForEachMutableIf(op func(K, *V) bool) {
|
||||
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||
if !op(e.key, &e.value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SkipList implementation part.
|
||||
|
||||
type skipListNode[K any, V any] struct {
|
||||
key K
|
||||
value V
|
||||
next []*skipListNode[K, V]
|
||||
}
|
||||
|
||||
//go:generate bash ./skiplist_newnode_generate.sh skipListMaxLevel skiplist_newnode.go
|
||||
// func newSkipListNode[K Ordered, V any](level int, key K, value V) *skipListNode[K, V]
|
||||
|
||||
type skipListIterator[K any, V any] struct {
|
||||
node, end *skipListNode[K, V]
|
||||
}
|
||||
|
||||
func (it *skipListIterator[K, V]) IsNotEnd() bool {
|
||||
return it.node != it.end
|
||||
}
|
||||
|
||||
func (it *skipListIterator[K, V]) MoveToNext() {
|
||||
it.node = it.node.next[0]
|
||||
}
|
||||
|
||||
func (it *skipListIterator[K, V]) Key() K {
|
||||
return it.node.key
|
||||
}
|
||||
|
||||
func (it *skipListIterator[K, V]) Value() V {
|
||||
return it.node.value
|
||||
}
|
||||
|
||||
// skipListImpl is an interface to provide different implementation for Ordered key or CompareFn.
|
||||
//
|
||||
// We can use CompareFn to cumpare Ordered keys, but a separated implementation is much faster.
|
||||
// We don't make the whole skip list an interface, in order to share the type independented method.
|
||||
// And because these methods are called directly without going through the interface, they are also
|
||||
// much faster.
|
||||
type skipListImpl[K any, V any] interface {
|
||||
findNode(key K) *skipListNode[K, V]
|
||||
lowerBound(key K) *skipListNode[K, V]
|
||||
upperBound(key K) *skipListNode[K, V]
|
||||
findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V])
|
||||
findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V])
|
||||
}
|
||||
|
||||
func (sl *SkipList[K, V]) init() {
|
||||
sl.level = 1
|
||||
// #nosec G404 -- This is not a security condition
|
||||
sl.rander = rand.New(rand.NewSource(time.Now().Unix()))
|
||||
sl.prevsCache = make([]*skipListNode[K, V], skipListMaxLevel)
|
||||
sl.head.next = make([]*skipListNode[K, V], skipListMaxLevel)
|
||||
}
|
||||
|
||||
func (sl *SkipList[K, V]) randomLevel() int {
|
||||
total := uint64(1)<<uint64(skipListMaxLevel) - 1 // 2^n-1
|
||||
k := sl.rander.Uint64() % total
|
||||
level := skipListMaxLevel - bits.Len64(k) + 1
|
||||
// Since levels are randomly generated, most should be less than log2(s.len).
|
||||
// Then make a limit according to sl.len to avoid unexpectedly large value.
|
||||
for level > 3 && 1<<(level-3) > sl.len {
|
||||
level--
|
||||
}
|
||||
|
||||
return level
|
||||
}
|
||||
|
||||
/// skipListOrdered part
|
||||
|
||||
// skipListOrdered is the skip list implementation for Ordered types.
|
||||
type skipListOrdered[K Ordered, V any] struct {
|
||||
SkipList[K, V]
|
||||
}
|
||||
|
||||
func (sl *skipListOrdered[K, V]) findNode(key K) *skipListNode[K, V] {
|
||||
return sl.doFindNode(key, true)
|
||||
}
|
||||
|
||||
func (sl *skipListOrdered[K, V]) doFindNode(key K, eq bool) *skipListNode[K, V] {
|
||||
// This function execute the job of findNode if eq is true, otherwise lowBound.
|
||||
// Passing the control variable eq is ugly but it's faster than testing node
|
||||
// again outside the function in findNode.
|
||||
prev := &sl.head
|
||||
for i := sl.level - 1; i >= 0; i-- {
|
||||
for cur := prev.next[i]; cur != nil; cur = cur.next[i] {
|
||||
if cur.key == key {
|
||||
return cur
|
||||
}
|
||||
if cur.key > key {
|
||||
// All other node in this level must be greater than the key,
|
||||
// search the next level.
|
||||
break
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
if eq {
|
||||
return nil
|
||||
}
|
||||
return prev.next[0]
|
||||
}
|
||||
|
||||
func (sl *skipListOrdered[K, V]) lowerBound(key K) *skipListNode[K, V] {
|
||||
return sl.doFindNode(key, false)
|
||||
}
|
||||
|
||||
func (sl *skipListOrdered[K, V]) upperBound(key K) *skipListNode[K, V] {
|
||||
node := sl.lowerBound(key)
|
||||
if node != nil && node.key == key {
|
||||
return node.next[0]
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// findInsertPoint returns (*node, nil) to the existed node if the key exists,
|
||||
// or (nil, []*node) to the previous nodes if the key doesn't exist
|
||||
func (sl *skipListOrdered[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||
prevs := sl.prevsCache[0:sl.level]
|
||||
prev := &sl.head
|
||||
for i := sl.level - 1; i >= 0; i-- {
|
||||
for next := prev.next[i]; next != nil; next = next.next[i] {
|
||||
if next.key == key {
|
||||
// The key is already existed, prevs are useless because no new node insertion.
|
||||
// stop searching.
|
||||
return next, nil
|
||||
}
|
||||
if next.key > key {
|
||||
// All other node in this level must be greater than the key,
|
||||
// search the next level.
|
||||
break
|
||||
}
|
||||
prev = next
|
||||
}
|
||||
prevs[i] = prev
|
||||
}
|
||||
return nil, prevs
|
||||
}
|
||||
|
||||
// findRemovePoint finds the node which match the key and it's previous nodes.
|
||||
func (sl *skipListOrdered[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||
prevs := sl.findPrevNodes(key)
|
||||
node := prevs[0].next[0]
|
||||
if node == nil || node.key != key {
|
||||
return nil, nil
|
||||
}
|
||||
return node, prevs
|
||||
}
|
||||
|
||||
func (sl *skipListOrdered[K, V]) findPrevNodes(key K) []*skipListNode[K, V] {
|
||||
prevs := sl.prevsCache[0:sl.level]
|
||||
prev := &sl.head
|
||||
for i := sl.level - 1; i >= 0; i-- {
|
||||
for next := prev.next[i]; next != nil; next = next.next[i] {
|
||||
if next.key >= key {
|
||||
break
|
||||
}
|
||||
prev = next
|
||||
}
|
||||
prevs[i] = prev
|
||||
}
|
||||
return prevs
|
||||
}
|
||||
|
||||
/// skipListFunc part
|
||||
|
||||
// skipListFunc is the skip list implementation which compare keys with func.
|
||||
type skipListFunc[K any, V any] struct {
|
||||
SkipList[K, V]
|
||||
keyCmp CompareFn[K]
|
||||
}
|
||||
|
||||
func (sl *skipListFunc[K, V]) findNode(key K) *skipListNode[K, V] {
|
||||
node := sl.lowerBound(key)
|
||||
if node != nil && sl.keyCmp(node.key, key) == 0 {
|
||||
return node
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sl *skipListFunc[K, V]) lowerBound(key K) *skipListNode[K, V] {
|
||||
var prev = &sl.head
|
||||
for i := sl.level - 1; i >= 0; i-- {
|
||||
cur := prev.next[i]
|
||||
for ; cur != nil; cur = cur.next[i] {
|
||||
cmpRet := sl.keyCmp(cur.key, key)
|
||||
if cmpRet == 0 {
|
||||
return cur
|
||||
}
|
||||
if cmpRet > 0 {
|
||||
break
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
return prev.next[0]
|
||||
}
|
||||
|
||||
func (sl *skipListFunc[K, V]) upperBound(key K) *skipListNode[K, V] {
|
||||
node := sl.lowerBound(key)
|
||||
if node != nil && sl.keyCmp(node.key, key) == 0 {
|
||||
return node.next[0]
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// findInsertPoint returns (*node, nil) to the existed node if the key exists,
|
||||
// or (nil, []*node) to the previous nodes if the key doesn't exist
|
||||
func (sl *skipListFunc[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||
prevs := sl.prevsCache[0:sl.level]
|
||||
prev := &sl.head
|
||||
for i := sl.level - 1; i >= 0; i-- {
|
||||
for cur := prev.next[i]; cur != nil; cur = cur.next[i] {
|
||||
r := sl.keyCmp(cur.key, key)
|
||||
if r == 0 {
|
||||
// The key is already existed, prevs are useless because no new node insertion.
|
||||
// stop searching.
|
||||
return cur, nil
|
||||
}
|
||||
if r > 0 {
|
||||
// All other node in this level must be greater than the key,
|
||||
// search the next level.
|
||||
break
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
prevs[i] = prev
|
||||
}
|
||||
return nil, prevs
|
||||
}
|
||||
|
||||
// findRemovePoint finds the node which match the key and it's previous nodes.
|
||||
func (sl *skipListFunc[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||
prevs := sl.findPrevNodes(key)
|
||||
node := prevs[0].next[0]
|
||||
if node == nil || sl.keyCmp(node.key, key) != 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return node, prevs
|
||||
}
|
||||
|
||||
func (sl *skipListFunc[K, V]) findPrevNodes(key K) []*skipListNode[K, V] {
|
||||
prevs := sl.prevsCache[0:sl.level]
|
||||
prev := &sl.head
|
||||
for i := sl.level - 1; i >= 0; i-- {
|
||||
for next := prev.next[i]; next != nil; next = next.next[i] {
|
||||
if sl.keyCmp(next.key, key) >= 0 {
|
||||
break
|
||||
}
|
||||
prev = next
|
||||
}
|
||||
prevs[i] = prev
|
||||
}
|
||||
return prevs
|
||||
}
|
297
transport/anytls/skiplist/skiplist_newnode.go
Normal file
297
transport/anytls/skiplist/skiplist_newnode.go
Normal file
|
@ -0,0 +1,297 @@
|
|||
// AUTO GENERATED CODE, DON'T EDIT!!!
|
||||
// EDIT skiplist_newnode_generate.sh accordingly.
|
||||
|
||||
package skiplist
|
||||
|
||||
// newSkipListNode creates a new node initialized with specified key, value and next slice.
|
||||
func newSkipListNode[K any, V any](level int, key K, value V) *skipListNode[K, V] {
|
||||
// For nodes with each levels, point their next slice to the nexts array allocated together,
|
||||
// which can reduce 1 memory allocation and improve performance.
|
||||
//
|
||||
// The generics of the golang doesn't support non-type parameters like in C++,
|
||||
// so we have to generate it manually.
|
||||
switch level {
|
||||
case 1:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [1]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 2:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [2]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 3:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [3]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 4:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [4]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 5:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [5]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 6:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [6]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 7:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [7]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 8:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [8]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 9:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [9]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 10:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [10]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 11:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [11]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 12:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [12]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 13:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [13]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 14:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [14]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 15:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [15]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 16:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [16]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 17:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [17]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 18:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [18]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 19:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [19]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 20:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [20]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 21:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [21]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 22:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [22]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 23:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [23]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 24:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [24]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 25:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [25]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 26:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [26]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 27:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [27]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 28:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [28]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 29:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [29]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 30:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [30]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 31:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [31]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 32:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [32]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 33:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [33]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 34:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [34]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 35:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [35]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 36:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [36]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 37:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [37]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 38:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [38]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 39:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [39]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
case 40:
|
||||
n := struct {
|
||||
head skipListNode[K, V]
|
||||
nexts [40]*skipListNode[K, V]
|
||||
}{head: skipListNode[K, V]{key, value, nil}}
|
||||
n.head.next = n.nexts[:]
|
||||
return &n.head
|
||||
}
|
||||
|
||||
panic("should not reach here")
|
||||
}
|
75
transport/anytls/skiplist/types.go
Normal file
75
transport/anytls/skiplist/types.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package skiplist
|
||||
|
||||
// Signed is a constraint that permits any signed integer type.
|
||||
// If future releases of Go add new predeclared signed integer types,
|
||||
// this constraint will be modified to include them.
|
||||
type Signed interface {
|
||||
~int | ~int8 | ~int16 | ~int32 | ~int64
|
||||
}
|
||||
|
||||
// Unsigned is a constraint that permits any unsigned integer type.
|
||||
// If future releases of Go add new predeclared unsigned integer types,
|
||||
// this constraint will be modified to include them.
|
||||
type Unsigned interface {
|
||||
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
|
||||
}
|
||||
|
||||
// Integer is a constraint that permits any integer type.
|
||||
// If future releases of Go add new predeclared integer types,
|
||||
// this constraint will be modified to include them.
|
||||
type Integer interface {
|
||||
Signed | Unsigned
|
||||
}
|
||||
|
||||
// Float is a constraint that permits any floating-point type.
|
||||
// If future releases of Go add new predeclared floating-point types,
|
||||
// this constraint will be modified to include them.
|
||||
type Float interface {
|
||||
~float32 | ~float64
|
||||
}
|
||||
|
||||
// Ordered is a constraint that permits any ordered type: any type
|
||||
// that supports the operators < <= >= >.
|
||||
// If future releases of Go add new ordered types,
|
||||
// this constraint will be modified to include them.
|
||||
type Ordered interface {
|
||||
Integer | Float | ~string
|
||||
}
|
||||
|
||||
// Numeric is a constraint that permits any numeric type.
|
||||
type Numeric interface {
|
||||
Integer | Float
|
||||
}
|
||||
|
||||
// LessFn is a function that returns whether 'a' is less than 'b'.
|
||||
type LessFn[T any] func(a, b T) bool
|
||||
|
||||
// CompareFn is a 3 way compare function that
|
||||
// returns 1 if a > b,
|
||||
// returns 0 if a == b,
|
||||
// returns -1 if a < b.
|
||||
type CompareFn[T any] func(a, b T) int
|
||||
|
||||
// HashFn is a function that returns the hash of 't'.
|
||||
type HashFn[T any] func(t T) uint64
|
||||
|
||||
// Equals wraps the '==' operator for comparable types.
|
||||
func Equals[T comparable](a, b T) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
// Less wraps the '<' operator for ordered types.
|
||||
func Less[T Ordered](a, b T) bool {
|
||||
return a < b
|
||||
}
|
||||
|
||||
// OrderedCompare provide default CompareFn for ordered types.
|
||||
func OrderedCompare[T Ordered](a, b T) int {
|
||||
if a < b {
|
||||
return -1
|
||||
}
|
||||
if a > b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
28
transport/anytls/util/routine.go
Normal file
28
transport/anytls/util/routine.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/log"
|
||||
)
|
||||
|
||||
func StartRoutine(ctx context.Context, d time.Duration, f func()) {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorln("[BUG] %v %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
time.Sleep(d)
|
||||
f()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
27
transport/anytls/util/string_map.go
Normal file
27
transport/anytls/util/string_map.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StringMap map[string]string
|
||||
|
||||
func (s StringMap) ToBytes() []byte {
|
||||
var lines []string
|
||||
for k, v := range s {
|
||||
lines = append(lines, k+"="+v)
|
||||
}
|
||||
return []byte(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
func StringMapFromBytes(b []byte) StringMap {
|
||||
var m = make(StringMap)
|
||||
var lines = strings.Split(string(b), "\n")
|
||||
for _, line := range lines {
|
||||
v := strings.SplitN(line, "=", 2)
|
||||
if len(v) == 2 {
|
||||
m[v[0]] = v[1]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
Loading…
Add table
Reference in a new issue