From 9962a0d091e017759b5a9a2a9a00c58ca0a303a9 Mon Sep 17 00:00:00 2001 From: anytls Date: Mon, 17 Feb 2025 18:51:11 +0800 Subject: [PATCH] feat: implement anytls client and server (#1844) --- adapter/outbound/anytls.go | 137 ++++++ adapter/parser.go | 7 + constant/adapters.go | 3 + constant/metadata.go | 5 + docs/config.yaml | 28 ++ listener/anytls/server.go | 181 +++++++ listener/config/anytls.go | 19 + listener/inbound/anytls.go | 79 +++ listener/parse.go | 7 + transport/anytls/client.go | 123 +++++ transport/anytls/padding/padding.go | 92 ++++ transport/anytls/session/client.go | 160 ++++++ transport/anytls/session/frame.go | 44 ++ transport/anytls/session/session.go | 379 +++++++++++++++ transport/anytls/session/stream.go | 99 ++++ transport/anytls/skiplist/contianer.go | 46 ++ transport/anytls/skiplist/skiplist.go | 455 ++++++++++++++++++ transport/anytls/skiplist/skiplist_newnode.go | 297 ++++++++++++ transport/anytls/skiplist/types.go | 75 +++ transport/anytls/util/routine.go | 28 ++ transport/anytls/util/string_map.go | 27 ++ 21 files changed, 2291 insertions(+) create mode 100644 adapter/outbound/anytls.go create mode 100644 listener/anytls/server.go create mode 100644 listener/config/anytls.go create mode 100644 listener/inbound/anytls.go create mode 100644 transport/anytls/client.go create mode 100644 transport/anytls/padding/padding.go create mode 100644 transport/anytls/session/client.go create mode 100644 transport/anytls/session/frame.go create mode 100644 transport/anytls/session/session.go create mode 100644 transport/anytls/session/stream.go create mode 100644 transport/anytls/skiplist/contianer.go create mode 100644 transport/anytls/skiplist/skiplist.go create mode 100644 transport/anytls/skiplist/skiplist_newnode.go create mode 100644 transport/anytls/skiplist/types.go create mode 100644 transport/anytls/util/routine.go create mode 100644 transport/anytls/util/string_map.go diff --git a/adapter/outbound/anytls.go b/adapter/outbound/anytls.go new file mode 100644 index 00000000..8af33f20 --- /dev/null +++ b/adapter/outbound/anytls.go @@ -0,0 +1,137 @@ +package outbound + +import ( + "context" + "crypto/tls" + "errors" + "net" + "runtime" + "strconv" + "time" + + CN "github.com/metacubex/mihomo/common/net" + "github.com/metacubex/mihomo/component/dialer" + "github.com/metacubex/mihomo/component/proxydialer" + "github.com/metacubex/mihomo/component/resolver" + tlsC "github.com/metacubex/mihomo/component/tls" + C "github.com/metacubex/mihomo/constant" + "github.com/metacubex/mihomo/transport/anytls" + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/uot" +) + +type AnyTLS struct { + *Base + client *anytls.Client + dialer proxydialer.SingDialer + option *AnyTLSOption +} + +type AnyTLSOption struct { + BasicOption + Name string `proxy:"name"` + Server string `proxy:"server"` + Port int `proxy:"port"` + Password string `proxy:"password"` + ALPN []string `proxy:"alpn,omitempty"` + SNI string `proxy:"sni,omitempty"` + ClientFingerprint string `proxy:"client-fingerprint,omitempty"` + SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` + UDP bool `proxy:"udp,omitempty"` + IdleSessionCheckInterval int `proxy:"idle-session-check-interval,omitempty"` + IdleSessionTimeout int `proxy:"idle-session-timeout,omitempty"` +} + +func (t *AnyTLS) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { + options := t.Base.DialOptions(opts...) + t.dialer.SetDialer(dialer.NewDialer(options...)) + c, err := t.client.CreateProxy(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) + if err != nil { + return nil, err + } + return NewConn(CN.NewRefConn(c, t), t), nil +} + +func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { + // create tcp + options := t.Base.DialOptions(opts...) + t.dialer.SetDialer(dialer.NewDialer(options...)) + c, err := t.client.CreateProxy(ctx, uot.RequestDestination(2)) + if err != nil { + return nil, err + } + + // create uot on tcp + if !metadata.Resolved() { + ip, err := resolver.ResolveIP(ctx, metadata.Host) + if err != nil { + return nil, errors.New("can't resolve ip") + } + metadata.DstIP = ip + } + destination := M.SocksaddrFromNet(metadata.UDPAddr()) + return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), t), nil +} + +// SupportUOT implements C.ProxyAdapter +func (t *AnyTLS) SupportUOT() bool { + return true +} + +// ProxyInfo implements C.ProxyAdapter +func (t *AnyTLS) ProxyInfo() C.ProxyInfo { + info := t.Base.ProxyInfo() + info.DialerProxy = t.option.DialerProxy + return info +} + +func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) { + addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) + + singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()) + + tOption := anytls.ClientConfig{ + Password: option.Password, + Server: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)), + Dialer: singDialer, + IdleSessionCheckInterval: time.Duration(option.IdleSessionCheckInterval) * time.Second, + IdleSessionTimeout: time.Duration(option.IdleSessionTimeout) * time.Second, + ClientFingerprint: option.ClientFingerprint, + } + tlsConfig := &tls.Config{ + ServerName: option.SNI, + InsecureSkipVerify: option.SkipCertVerify, + NextProtos: option.ALPN, + } + if tlsConfig.ServerName == "" { + tlsConfig.ServerName = "127.0.0.1" + } + tOption.TLSConfig = tlsConfig + + if tlsC.HaveGlobalFingerprint() && len(tOption.ClientFingerprint) == 0 { + tOption.ClientFingerprint = tlsC.GetGlobalFingerprint() + } + + outbound := &AnyTLS{ + Base: &Base{ + name: option.Name, + addr: addr, + tp: C.AnyTLS, + udp: option.UDP, + tfo: option.TFO, + mpTcp: option.MPTCP, + iface: option.Interface, + rmark: option.RoutingMark, + prefer: C.NewDNSPrefer(option.IPVersion), + }, + client: anytls.NewClient(context.TODO(), tOption), + option: &option, + dialer: singDialer, + } + runtime.SetFinalizer(outbound, func(o *AnyTLS) { + common.Close(o.client) + }) + + return outbound, nil +} diff --git a/adapter/parser.go b/adapter/parser.go index ce4e91d5..9b256e6d 100644 --- a/adapter/parser.go +++ b/adapter/parser.go @@ -148,6 +148,13 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) { break } proxy, err = outbound.NewMieru(*mieruOption) + case "anytls": + anytlsOption := &outbound.AnyTLSOption{} + err = decoder.Decode(mapping, anytlsOption) + if err != nil { + break + } + proxy, err = outbound.NewAnyTLS(*anytlsOption) default: return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) } diff --git a/constant/adapters.go b/constant/adapters.go index 420a797f..b6b104c9 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -43,6 +43,7 @@ const ( Tuic Ssh Mieru + AnyTLS ) const ( @@ -229,6 +230,8 @@ func (at AdapterType) String() string { return "Ssh" case Mieru: return "Mieru" + case AnyTLS: + return "AnyTLS" case Relay: return "Relay" case Selector: diff --git a/constant/metadata.go b/constant/metadata.go index e4167845..ac676b04 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -32,6 +32,7 @@ const ( TUN TUIC HYSTERIA2 + ANYTLS INNER ) @@ -84,6 +85,8 @@ func (t Type) String() string { return "Tuic" case HYSTERIA2: return "Hysteria2" + case ANYTLS: + return "AnyTLS" case INNER: return "Inner" default: @@ -120,6 +123,8 @@ func ParseType(t string) (*Type, error) { res = TUIC case "HYSTERIA2": res = HYSTERIA2 + case "ANYTLS": + res = ANYTLS case "INNER": res = INNER default: diff --git a/docs/config.yaml b/docs/config.yaml index ad77debb..db66a2b3 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -864,6 +864,22 @@ proxies: # socks5 # 可以使用的值包括 MULTIPLEXING_OFF, MULTIPLEXING_LOW, MULTIPLEXING_MIDDLE, MULTIPLEXING_HIGH。其中 MULTIPLEXING_OFF 会关闭多路复用功能。默认值为 MULTIPLEXING_LOW。 # multiplexing: MULTIPLEXING_LOW + # anytls + - name: anytls + type: anytls + server: 1.2.3.4 + port: 443 + password: "" + # client-fingerprint: chrome + udp: true + # idle-session-check-interval: 30 # seconds + # idle-session-timeout: 30 # seconds + # sni: "example.com" + # alpn: + # - h2 + # - http/1.1 + # skip-cert-verify: true + # dns 出站会将请求劫持到内部 dns 模块,所有请求均在内部处理 - name: "dns-out" type: dns @@ -1209,6 +1225,18 @@ listeners: - test.com ### 注意,对于vless listener, 至少需要填写 “certificate和private-key” 或 “reality-config” 的其中一项 ### + - name: anytls-in-1 + type: anytls + port: 10818 + listen: 0.0.0.0 + users: + username1: password1 + username2: password2 + # "certificate" and "private-key" are required + certificate: ./server.crt + private-key: ./server.key + # padding-scheme: "" # https://github.com/anytls/anytls-go/blob/main/docs/protocol.md#cmdupdatepaddingscheme + - name: tun-in-1 type: tun # rule: sub-rule-name1 # 默认使用 rules,如果未找到 sub-rule 则直接使用 rules diff --git a/listener/anytls/server.go b/listener/anytls/server.go new file mode 100644 index 00000000..5d860e8a --- /dev/null +++ b/listener/anytls/server.go @@ -0,0 +1,181 @@ +package anytls + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "encoding/binary" + "errors" + "net" + "strings" + + "github.com/metacubex/mihomo/adapter/inbound" + "github.com/metacubex/mihomo/common/buf" + N "github.com/metacubex/mihomo/common/net" + C "github.com/metacubex/mihomo/constant" + LC "github.com/metacubex/mihomo/listener/config" + "github.com/metacubex/mihomo/listener/sing" + "github.com/metacubex/mihomo/transport/anytls/padding" + "github.com/metacubex/mihomo/transport/anytls/session" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" +) + +type Listener struct { + closed bool + config LC.AnyTLSServer + listeners []net.Listener + tlsConfig *tls.Config + userMap map[[32]byte]string + padding atomic.TypedValue[*padding.PaddingFactory] +} + +func New(config LC.AnyTLSServer, tunnel C.Tunnel, additions ...inbound.Addition) (sl *Listener, err error) { + if len(additions) == 0 { + additions = []inbound.Addition{ + inbound.WithInName("DEFAULT-ANYTLS"), + inbound.WithSpecialRules(""), + } + } + + tlsConfig := &tls.Config{} + if config.Certificate != "" && config.PrivateKey != "" { + cert, err := N.ParseCert(config.Certificate, config.PrivateKey, C.Path) + if err != nil { + return nil, err + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + sl = &Listener{ + config: config, + tlsConfig: tlsConfig, + userMap: make(map[[32]byte]string), + } + + for user, password := range config.Users { + sl.userMap[sha256.Sum256([]byte(password))] = user + } + + if len(config.PaddingScheme) > 0 { + if !padding.UpdatePaddingScheme([]byte(config.PaddingScheme), &sl.padding) { + return nil, errors.New("incorrect padding scheme format") + } + } else { + padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &sl.padding) + } + + // Using sing handler can automatically handle UoT + h, err := sing.NewListenerHandler(sing.ListenerConfig{ + Tunnel: tunnel, + Type: C.ANYTLS, + Additions: additions, + }) + if err != nil { + return nil, err + } + + for _, addr := range strings.Split(config.Listen, ",") { + addr := addr + + //TCP + l, err := inbound.Listen("tcp", addr) + if err != nil { + return nil, err + } + sl.listeners = append(sl.listeners, l) + + go func() { + for { + c, err := l.Accept() + if err != nil { + if sl.closed { + break + } + continue + } + go sl.HandleConn(c, h) + } + }() + } + + return sl, nil +} + +func (l *Listener) Close() error { + l.closed = true + var retErr error + for _, lis := range l.listeners { + err := lis.Close() + if err != nil { + retErr = err + } + } + return retErr +} + +func (l *Listener) Config() string { + return l.config.String() +} + +func (l *Listener) AddrList() (addrList []net.Addr) { + for _, lis := range l.listeners { + addrList = append(addrList, lis.Addr()) + } + return +} + +func (l *Listener) HandleConn(conn net.Conn, h *sing.ListenerHandler) { + ctx := context.TODO() + + conn = tls.Server(conn, l.tlsConfig) + defer conn.Close() + + b := buf.NewPacket() + _, err := b.ReadOnceFrom(conn) + if err != nil { + return + } + conn = bufio.NewCachedConn(conn, b) + + by, err := b.ReadBytes(32) + if err != nil { + return + } + var passwordSha256 [32]byte + copy(passwordSha256[:], by) + if user, ok := l.userMap[passwordSha256]; ok { + ctx = auth.ContextWithUser(ctx, user) + } else { + return + } + by, err = b.ReadBytes(2) + if err != nil { + return + } + paddingLen := binary.BigEndian.Uint16(by) + if paddingLen > 0 { + _, err = b.ReadBytes(int(paddingLen)) + if err != nil { + return + } + } + + session := session.NewServerSession(conn, func(stream *session.Stream) { + defer stream.Close() + + destination, err := M.SocksaddrSerializer.ReadAddrPort(stream) + if err != nil { + return + } + + h.NewConnection(ctx, stream, M.Metadata{ + Source: M.SocksaddrFromNet(conn.RemoteAddr()), + Destination: destination, + }) + }, &l.padding) + session.Run(true) + session.Close() +} diff --git a/listener/config/anytls.go b/listener/config/anytls.go new file mode 100644 index 00000000..adbafa60 --- /dev/null +++ b/listener/config/anytls.go @@ -0,0 +1,19 @@ +package config + +import ( + "encoding/json" +) + +type AnyTLSServer struct { + Enable bool `yaml:"enable" json:"enable"` + Listen string `yaml:"listen" json:"listen"` + Users map[string]string `yaml:"users" json:"users,omitempty"` + Certificate string `yaml:"certificate" json:"certificate"` + PrivateKey string `yaml:"private-key" json:"private-key"` + PaddingScheme string `yaml:"padding-scheme" json:"padding-scheme,omitempty"` +} + +func (t AnyTLSServer) String() string { + b, _ := json.Marshal(t) + return string(b) +} diff --git a/listener/inbound/anytls.go b/listener/inbound/anytls.go new file mode 100644 index 00000000..a995bf4f --- /dev/null +++ b/listener/inbound/anytls.go @@ -0,0 +1,79 @@ +package inbound + +import ( + C "github.com/metacubex/mihomo/constant" + "github.com/metacubex/mihomo/listener/anytls" + LC "github.com/metacubex/mihomo/listener/config" + "github.com/metacubex/mihomo/log" +) + +type AnyTLSOption struct { + BaseOption + Users map[string]string `inbound:"users,omitempty"` + Certificate string `inbound:"certificate"` + PrivateKey string `inbound:"private-key"` + PaddingScheme string `inbound:"padding-scheme,omitempty"` +} + +func (o AnyTLSOption) Equal(config C.InboundConfig) bool { + return optionToString(o) == optionToString(config) +} + +type AnyTLS struct { + *Base + config *AnyTLSOption + l C.MultiAddrListener + vs LC.AnyTLSServer +} + +func NewAnyTLS(options *AnyTLSOption) (*AnyTLS, error) { + base, err := NewBase(&options.BaseOption) + if err != nil { + return nil, err + } + return &AnyTLS{ + Base: base, + config: options, + vs: LC.AnyTLSServer{ + Enable: true, + Listen: base.RawAddress(), + Users: options.Users, + Certificate: options.Certificate, + PrivateKey: options.PrivateKey, + PaddingScheme: options.PaddingScheme, + }, + }, nil +} + +// Config implements constant.InboundListener +func (v *AnyTLS) Config() C.InboundConfig { + return v.config +} + +// Address implements constant.InboundListener +func (v *AnyTLS) Address() string { + if v.l != nil { + for _, addr := range v.l.AddrList() { + return addr.String() + } + } + return "" +} + +// Listen implements constant.InboundListener +func (v *AnyTLS) Listen(tunnel C.Tunnel) error { + var err error + v.l, err = anytls.New(v.vs, tunnel, v.Additions()...) + if err != nil { + return err + } + log.Infoln("AnyTLS[%s] proxy listening at: %s", v.Name(), v.Address()) + return nil +} + +// Close implements constant.InboundListener +func (v *AnyTLS) Close() error { + return v.l.Close() +} + +var _ C.InboundListener = (*AnyTLS)(nil) diff --git a/listener/parse.go b/listener/parse.go index 38082e92..5c5d6c7e 100644 --- a/listener/parse.go +++ b/listener/parse.go @@ -113,6 +113,13 @@ func ParseListener(mapping map[string]any) (C.InboundListener, error) { return nil, err } listener, err = IN.NewTuic(tuicOption) + case "anytls": + anytlsOption := &IN.AnyTLSOption{} + err = decoder.Decode(mapping, anytlsOption) + if err != nil { + return nil, err + } + listener, err = IN.NewAnyTLS(anytlsOption) default: return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) } diff --git a/transport/anytls/client.go b/transport/anytls/client.go new file mode 100644 index 00000000..2076019e --- /dev/null +++ b/transport/anytls/client.go @@ -0,0 +1,123 @@ +package anytls + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "encoding/binary" + "net" + "time" + + tlsC "github.com/metacubex/mihomo/component/tls" + C "github.com/metacubex/mihomo/constant" + "github.com/metacubex/mihomo/transport/anytls/padding" + "github.com/metacubex/mihomo/transport/anytls/session" + "github.com/metacubex/mihomo/transport/vmess" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type ClientConfig struct { + Password string + IdleSessionCheckInterval time.Duration + IdleSessionTimeout time.Duration + Server M.Socksaddr + Dialer N.Dialer + TLSConfig *tls.Config + ClientFingerprint string +} + +type Client struct { + passwordSha256 []byte + tlsConfig *tls.Config + clientFingerprint string + dialer N.Dialer + server M.Socksaddr + sessionClient *session.Client + padding atomic.TypedValue[*padding.PaddingFactory] +} + +func NewClient(ctx context.Context, config ClientConfig) *Client { + pw := sha256.Sum256([]byte(config.Password)) + c := &Client{ + passwordSha256: pw[:], + tlsConfig: config.TLSConfig, + clientFingerprint: config.ClientFingerprint, + dialer: config.Dialer, + server: config.Server, + } + // Initialize the padding state of this client + padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding) + c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout) + return c +} + +func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + conn, err := c.sessionClient.CreateStream(ctx) + if err != nil { + return nil, err + } + err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) + if err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) + if err != nil { + return nil, err + } + + b := buf.NewPacket() + b.Write(c.passwordSha256) + var paddingLen int + if pad := c.padding.Load().GenerateRecordPayloadSizes(0); len(pad) > 0 { + paddingLen = pad[0] + } + binary.BigEndian.PutUint16(b.Extend(2), uint16(paddingLen)) + if paddingLen > 0 { + b.WriteZeroN(paddingLen) + } + + getTlsConn := func() (net.Conn, error) { + if len(c.clientFingerprint) != 0 { + utlsConn, valid := vmess.GetUTLSConn(conn, c.clientFingerprint, c.tlsConfig) + if valid { + ctx, cancel := context.WithTimeout(ctx, C.DefaultTLSTimeout) + defer cancel() + + err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) + return utlsConn, err + } + } + + tlsConn := tls.Client(conn, c.tlsConfig) + + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + + err = tlsConn.HandshakeContext(ctx) + return tlsConn, err + } + tlsConn, err := getTlsConn() + if err != nil { + conn.Close() + return nil, err + } + + _, err = b.WriteTo(tlsConn) + if err != nil { + tlsConn.Close() + return nil, err + } + return tlsConn, nil +} + +func (h *Client) Close() error { + return h.sessionClient.Close() +} diff --git a/transport/anytls/padding/padding.go b/transport/anytls/padding/padding.go new file mode 100644 index 00000000..e881e573 --- /dev/null +++ b/transport/anytls/padding/padding.go @@ -0,0 +1,92 @@ +package padding + +import ( + "crypto/md5" + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" + + "github.com/metacubex/mihomo/transport/anytls/util" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" +) + +const CheckMark = -1 + +var DefaultPaddingScheme = []byte(`stop=8 +0=34-120 +1=100-400 +2=400-500,c,500-1000,c,400-500,c,500-1000,c,500-1000,c,400-500 +3=500-1000 +4=500-1000 +5=500-1000 +6=500-1000 +7=500-1000`) + +type PaddingFactory struct { + scheme util.StringMap + RawScheme []byte + Stop uint32 + Md5 string +} + +func UpdatePaddingScheme(rawScheme []byte, to *atomic.TypedValue[*PaddingFactory]) bool { + if p := NewPaddingFactory(rawScheme); p != nil { + to.Store(p) + return true + } + return false +} + +func NewPaddingFactory(rawScheme []byte) *PaddingFactory { + p := &PaddingFactory{ + RawScheme: rawScheme, + Md5: fmt.Sprintf("%x", md5.Sum(rawScheme)), + } + scheme := util.StringMapFromBytes(rawScheme) + if len(scheme) == 0 { + return nil + } + if stop, err := strconv.Atoi(scheme["stop"]); err == nil { + p.Stop = uint32(stop) + } else { + return nil + } + p.scheme = scheme + return p +} + +func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) (pktSizes []int) { + if s, ok := p.scheme[strconv.Itoa(int(pkt))]; ok { + sRanges := strings.Split(s, ",") + for _, sRange := range sRanges { + sRangeMinMax := strings.Split(sRange, "-") + if len(sRangeMinMax) == 2 { + _min, err := strconv.ParseInt(sRangeMinMax[0], 10, 64) + if err != nil { + continue + } + _max, err := strconv.ParseInt(sRangeMinMax[1], 10, 64) + if err != nil { + continue + } + _min, _max = common.Min(_min, _max), common.Max(_min, _max) + if _min <= 0 || _max <= 0 { + continue + } + if _min == _max { + pktSizes = append(pktSizes, int(_min)) + } else { + i, _ := rand.Int(rand.Reader, big.NewInt(_max-_min)) + pktSizes = append(pktSizes, int(i.Int64()+_min)) + } + } else if sRange == "c" { + pktSizes = append(pktSizes, CheckMark) + } + } + } + return +} diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go new file mode 100644 index 00000000..5a853478 --- /dev/null +++ b/transport/anytls/session/client.go @@ -0,0 +1,160 @@ +package session + +import ( + "context" + "fmt" + "io" + "math" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/metacubex/mihomo/transport/anytls/padding" + "github.com/metacubex/mihomo/transport/anytls/skiplist" + "github.com/metacubex/mihomo/transport/anytls/util" + "github.com/sagernet/sing/common" + singAtomic "github.com/sagernet/sing/common/atomic" +) + +type Client struct { + die context.Context + dieCancel context.CancelFunc + + dialOut func(ctx context.Context) (net.Conn, error) + + sessionCounter atomic.Uint64 + idleSession *skiplist.SkipList[uint64, *Session] + idleSessionLock sync.Mutex + + padding *singAtomic.TypedValue[*padding.PaddingFactory] + + idleSessionTimeout time.Duration +} + +func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client { + c := &Client{ + dialOut: dialOut, + padding: _padding, + idleSessionTimeout: idleSessionTimeout, + } + if idleSessionCheckInterval <= time.Second*5 { + idleSessionCheckInterval = time.Second * 30 + } + if c.idleSessionTimeout <= time.Second*5 { + c.idleSessionTimeout = time.Second * 30 + } + c.die, c.dieCancel = context.WithCancel(ctx) + c.idleSession = skiplist.NewSkipList[uint64, *Session]() + util.StartRoutine(c.die, idleSessionCheckInterval, c.idleCleanup) + return c +} + +func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { + select { + case <-c.die.Done(): + return nil, io.ErrClosedPipe + default: + } + + var session *Session + var stream *Stream + var err error + + for i := 0; i < 3; i++ { + session, err = c.findSession(ctx) + if session == nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + stream, err = session.OpenStream() + if err != nil { + common.Close(session, stream) + continue + } + break + } + if session == nil || stream == nil { + return nil, fmt.Errorf("too many closed session: %w", err) + } + + stream.dieHook = func() { + if session.IsClosed() { + if session.dieHook != nil { + session.dieHook() + } + } else { + c.idleSessionLock.Lock() + session.idleSince = time.Now() + c.idleSession.Insert(math.MaxUint64-session.seq, session) + c.idleSessionLock.Unlock() + } + } + + return stream, nil +} + +func (c *Client) findSession(ctx context.Context) (*Session, error) { + var idle *Session + + c.idleSessionLock.Lock() + if !c.idleSession.IsEmpty() { + it := c.idleSession.Iterate() + idle = it.Value() + c.idleSession.Remove(it.Key()) + } + c.idleSessionLock.Unlock() + + if idle == nil { + s, err := c.createSession(ctx) + return s, err + } + return idle, nil +} + +func (c *Client) createSession(ctx context.Context) (*Session, error) { + underlying, err := c.dialOut(ctx) + if err != nil { + return nil, err + } + + session := NewClientSession(underlying, c.padding) + session.seq = c.sessionCounter.Add(1) + session.dieHook = func() { + //logrus.Debugln("session died", session) + c.idleSessionLock.Lock() + c.idleSession.Remove(math.MaxUint64 - session.seq) + c.idleSessionLock.Unlock() + } + session.Run(false) + return session, nil +} + +func (c *Client) Close() error { + c.dieCancel() + go c.idleCleanupExpTime(time.Now()) + return nil +} + +func (c *Client) idleCleanup() { + c.idleCleanupExpTime(time.Now().Add(-c.idleSessionTimeout)) +} + +func (c *Client) idleCleanupExpTime(expTime time.Time) { + var sessionToRemove = make([]*Session, 0) + + c.idleSessionLock.Lock() + it := c.idleSession.Iterate() + for it.IsNotEnd() { + session := it.Value() + if session.idleSince.Before(expTime) { + sessionToRemove = append(sessionToRemove, session) + c.idleSession.Remove(it.Key()) + } + it.MoveToNext() + } + c.idleSessionLock.Unlock() + + for _, session := range sessionToRemove { + session.Close() + } +} diff --git a/transport/anytls/session/frame.go b/transport/anytls/session/frame.go new file mode 100644 index 00000000..49597c55 --- /dev/null +++ b/transport/anytls/session/frame.go @@ -0,0 +1,44 @@ +package session + +import ( + "encoding/binary" +) + +const ( // cmds + cmdWaste = 0 // Paddings + cmdSYN = 1 // stream open + cmdPSH = 2 // data push + cmdFIN = 3 // stream close, a.k.a EOF mark + cmdSettings = 4 // Settings + cmdAlert = 5 // Alert + cmdUpdatePaddingScheme = 6 // update padding scheme +) + +const ( + headerOverHeadSize = 1 + 4 + 2 +) + +// frame defines a packet from or to be multiplexed into a single connection +type frame struct { + cmd byte // 1 + sid uint32 // 4 + data []byte // 2 + len(data) +} + +func newFrame(cmd byte, sid uint32) frame { + return frame{cmd: cmd, sid: sid} +} + +type rawHeader [headerOverHeadSize]byte + +func (h rawHeader) Cmd() byte { + return h[0] +} + +func (h rawHeader) StreamID() uint32 { + return binary.BigEndian.Uint32(h[1:]) +} + +func (h rawHeader) Length() uint16 { + return binary.BigEndian.Uint16(h[5:]) +} diff --git a/transport/anytls/session/session.go b/transport/anytls/session/session.go new file mode 100644 index 00000000..e7186b63 --- /dev/null +++ b/transport/anytls/session/session.go @@ -0,0 +1,379 @@ +package session + +import ( + "crypto/md5" + "encoding/binary" + "io" + "net" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/metacubex/mihomo/constant" + "github.com/metacubex/mihomo/log" + "github.com/metacubex/mihomo/transport/anytls/padding" + "github.com/metacubex/mihomo/transport/anytls/util" + singAtomic "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" +) + +type Session struct { + conn net.Conn + connLock sync.Mutex + + streams map[uint32]*Stream + streamId atomic.Uint32 + streamLock sync.RWMutex + + dieOnce sync.Once + die chan struct{} + dieHook func() + + // pool + seq uint64 + idleSince time.Time + padding *singAtomic.TypedValue[*padding.PaddingFactory] + + // client + isClient bool + buffering bool + buffer []byte + pktCounter atomic.Uint32 + + // server + onNewStream func(stream *Stream) +} + +func NewClientSession(conn net.Conn, _padding *singAtomic.TypedValue[*padding.PaddingFactory]) *Session { + s := &Session{ + conn: conn, + isClient: true, + padding: _padding, + } + s.die = make(chan struct{}) + s.streams = make(map[uint32]*Stream) + return s +} + +func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *singAtomic.TypedValue[*padding.PaddingFactory]) *Session { + s := &Session{ + conn: conn, + onNewStream: onNewStream, + isClient: false, + padding: _padding, + } + s.die = make(chan struct{}) + s.streams = make(map[uint32]*Stream) + return s +} + +func (s *Session) Run(isServer bool) { + if isServer { + s.recvLoop() + return + } + + settings := util.StringMap{ + "v": "1", + "client": "mihomo/" + constant.Version, + "padding-md5": s.padding.Load().Md5, + } + f := newFrame(cmdSettings, 0) + f.data = settings.ToBytes() + s.buffering = true + s.writeFrame(f) + + go s.recvLoop() +} + +// IsClosed does a safe check to see if we have shutdown +func (s *Session) IsClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + +// Close is used to close the session and all streams. +func (s *Session) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + if s.dieHook != nil { + s.dieHook() + } + s.streamLock.Lock() + for k := range s.streams { + s.streams[k].sessionClose() + } + s.streamLock.Unlock() + return s.conn.Close() + } else { + return io.ErrClosedPipe + } +} + +// OpenStream is used to create a new stream for CLIENT +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, io.ErrClosedPipe + } + + sid := s.streamId.Add(1) + stream := newStream(sid, s) + + //logrus.Debugln("stream open", sid, s.streams) + + if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + return nil, err + } + + s.buffering = false // proxy Write it's SocksAddr to flush the buffer + + s.streamLock.Lock() + defer s.streamLock.Unlock() + select { + case <-s.die: + return nil, io.ErrClosedPipe + default: + s.streams[sid] = stream + return stream, nil + } +} + +func (s *Session) recvLoop() error { + defer func() { + if r := recover(); r != nil { + log.Errorln("[BUG] %v %s", r, string(debug.Stack())) + } + }() + defer s.Close() + + var receivedSettingsFromClient bool + var hdr rawHeader + + for { + if s.IsClosed() { + return io.ErrClosedPipe + } + // read header first + if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { + sid := hdr.StreamID() + switch hdr.Cmd() { + case cmdPSH: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err == nil { + s.streamLock.RLock() + stream, ok := s.streams[sid] + s.streamLock.RUnlock() + if ok { + stream.pipeW.Write(buffer) + } + buf.Put(buffer) + } else { + buf.Put(buffer) + return err + } + } + case cmdSYN: // should be server only + if !s.isClient && !receivedSettingsFromClient { + f := newFrame(cmdAlert, 0) + f.data = []byte("client did not send its settings") + s.writeFrame(f) + return nil + } + s.streamLock.Lock() + if _, ok := s.streams[sid]; !ok { + stream := newStream(sid, s) + s.streams[sid] = stream + if s.onNewStream != nil { + go s.onNewStream(stream) + } else { + go s.Close() + } + } + s.streamLock.Unlock() + case cmdFIN: + s.streamLock.RLock() + stream, ok := s.streams[sid] + s.streamLock.RUnlock() + if ok { + stream.Close() + } + //logrus.Debugln("stream fin", sid, s.streams) + case cmdWaste: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + buf.Put(buffer) + } + case cmdSettings: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + if !s.isClient { + receivedSettingsFromClient = true + m := util.StringMapFromBytes(buffer) + paddingF := s.padding.Load() + if m["padding-md5"] != paddingF.Md5 { + // logrus.Debugln("remote md5 is", m["padding-md5"]) + f := newFrame(cmdUpdatePaddingScheme, 0) + f.data = paddingF.RawScheme + _, err = s.writeFrame(f) + if err != nil { + buf.Put(buffer) + return err + } + } + } + buf.Put(buffer) + } + case cmdAlert: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + if s.isClient { + log.Errorln("[Alert from server] %s", string(buffer)) + } + buf.Put(buffer) + return nil + } + case cmdUpdatePaddingScheme: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + if s.isClient { + if padding.UpdatePaddingScheme(buffer, s.padding) { + log.Infoln("[Update padding succeed] %x\n", md5.Sum(buffer)) + } else { + log.Warnln("[Update padding failed] %x\n", md5.Sum(buffer)) + } + } + buf.Put(buffer) + } + default: + // I don't know what command it is (can't have data) + } + } else { + return err + } + } +} + +// notify the session that a stream has closed +func (s *Session) streamClosed(sid uint32) error { + _, err := s.writeFrame(newFrame(cmdFIN, sid)) + s.streamLock.Lock() + delete(s.streams, sid) + s.streamLock.Unlock() + return err +} + +func (s *Session) writeFrame(frame frame) (int, error) { + dataLen := len(frame.data) + + buffer := buf.NewSize(dataLen + headerOverHeadSize) + buffer.WriteByte(frame.cmd) + binary.BigEndian.PutUint32(buffer.Extend(4), frame.sid) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen)) + buffer.Write(frame.data) + _, err := s.writeConn(buffer.Bytes()) + buffer.Release() + if err != nil { + return 0, err + } + + return dataLen, nil +} + +func (s *Session) writeConn(b []byte) (n int, err error) { + s.connLock.Lock() + defer s.connLock.Unlock() + + if s.buffering { + s.buffer = append(s.buffer, b...) + return len(b), nil + } else if len(s.buffer) > 0 { + b = append(s.buffer, b...) + s.buffer = nil + } + + // calulate & send padding + if s.isClient { + pkt := s.pktCounter.Add(1) + paddingF := s.padding.Load() + if pkt < paddingF.Stop { + pktSizes := paddingF.GenerateRecordPayloadSizes(pkt) + for _, l := range pktSizes { + remainPayloadLen := len(b) + if l == padding.CheckMark { + if remainPayloadLen == 0 { + break + } else { + continue + } + } + // logrus.Debugln(pkt, "write", l, "len", remainPayloadLen, "remain", remainPayloadLen-l) + if remainPayloadLen > l { // this packet is all payload + _, err = s.conn.Write(b[:l]) + if err != nil { + return 0, err + } + n += l + b = b[l:] + } else if remainPayloadLen > 0 { // this packet contains padding and the last part of payload + paddingLen := l - remainPayloadLen + if paddingLen > 0 { + padding := make([]byte, headerOverHeadSize+paddingLen) + padding[0] = cmdWaste + binary.BigEndian.PutUint32(padding[1:5], 0) + binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen)) + b = append(b, padding...) + } + _, err = s.conn.Write(b) + if err != nil { + return 0, err + } + n += remainPayloadLen + b = nil + } else { // this packet is all padding + padding := make([]byte, headerOverHeadSize+l) + padding[0] = cmdWaste + binary.BigEndian.PutUint32(padding[1:5], 0) + binary.BigEndian.PutUint16(padding[5:7], uint16(l)) + _, err = s.conn.Write(b) + if err != nil { + return 0, err + } + b = nil + } + } + // maybe still remain payload to write + if len(b) == 0 { + return + } + } + } + + return s.conn.Write(b) +} diff --git a/transport/anytls/session/stream.go b/transport/anytls/session/stream.go new file mode 100644 index 00000000..140396e4 --- /dev/null +++ b/transport/anytls/session/stream.go @@ -0,0 +1,99 @@ +package session + +import ( + "io" + "net" + "os" + "sync" + "time" +) + +// Stream implements net.Conn +type Stream struct { + id uint32 + + sess *Session + + pipeR *io.PipeReader + pipeW *io.PipeWriter + + dieOnce sync.Once + dieHook func() +} + +// newStream initiates a Stream struct +func newStream(id uint32, sess *Session) *Stream { + s := new(Stream) + s.id = id + s.sess = sess + s.pipeR, s.pipeW = io.Pipe() + return s +} + +// Read implements net.Conn +func (s *Stream) Read(b []byte) (n int, err error) { + return s.pipeR.Read(b) +} + +// Write implements net.Conn +func (s *Stream) Write(b []byte) (n int, err error) { + f := newFrame(cmdPSH, s.id) + f.data = b + n, err = s.sess.writeFrame(f) + return +} + +// Close implements net.Conn +func (s *Stream) Close() error { + if s.sessionClose() { + // notify remote + return s.sess.streamClosed(s.id) + } else { + return io.ErrClosedPipe + } +} + +// sessionClose close stream from session side, do not notify remote +func (s *Stream) sessionClose() (once bool) { + s.dieOnce.Do(func() { + s.pipeR.Close() + once = true + if s.dieHook != nil { + s.dieHook() + s.dieHook = nil + } + }) + return +} + +func (s *Stream) SetReadDeadline(t time.Time) error { + return os.ErrNotExist +} + +func (s *Stream) SetWriteDeadline(t time.Time) error { + return os.ErrNotExist +} + +func (s *Stream) SetDeadline(t time.Time) error { + return os.ErrNotExist +} + +// LocalAddr satisfies net.Conn interface +func (s *Stream) LocalAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *Stream) RemoteAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} diff --git a/transport/anytls/skiplist/contianer.go b/transport/anytls/skiplist/contianer.go new file mode 100644 index 00000000..ceda0421 --- /dev/null +++ b/transport/anytls/skiplist/contianer.go @@ -0,0 +1,46 @@ +package skiplist + +// Container is a holder object that stores a collection of other objects. +type Container interface { + IsEmpty() bool // IsEmpty checks if the container has no elements. + Len() int // Len returns the number of elements in the container. + Clear() // Clear erases all elements from the container. After this call, Len() returns zero. +} + +// Map is a associative container that contains key-value pairs with unique keys. +type Map[K any, V any] interface { + Container + Has(K) bool // Checks whether the container contains element with specific key. + Find(K) *V // Finds element with specific key. + Insert(K, V) // Inserts a key-value pair in to the container or replace existing value. + Remove(K) bool // Remove element with specific key. + ForEach(func(K, V)) // Iterate the container. + ForEachIf(func(K, V) bool) // Iterate the container, stops when the callback returns false. + ForEachMutable(func(K, *V)) // Iterate the container, *V is mutable. + ForEachMutableIf(func(K, *V) bool) // Iterate the container, *V is mutable, stops when the callback returns false. +} + +// Set is a containers that store unique elements. +type Set[K any] interface { + Container + Has(K) bool // Checks whether the container contains element with specific key. + Insert(K) // Inserts a key-value pair in to the container or replace existing value. + InsertN(...K) // Inserts multiple key-value pairs in to the container or replace existing value. + Remove(K) bool // Remove element with specific key. + RemoveN(...K) // Remove multiple elements with specific keys. + ForEach(func(K)) // Iterate the container. + ForEachIf(func(K) bool) // Iterate the container, stops when the callback returns false. +} + +// Iterator is the interface for container's iterator. +type Iterator[T any] interface { + IsNotEnd() bool // Whether it is point to the end of the range. + MoveToNext() // Let it point to the next element. + Value() T // Return the value of current element. +} + +// MapIterator is the interface for map's iterator. +type MapIterator[K any, V any] interface { + Iterator[V] + Key() K // The key of the element +} diff --git a/transport/anytls/skiplist/skiplist.go b/transport/anytls/skiplist/skiplist.go new file mode 100644 index 00000000..a4a0ffbb --- /dev/null +++ b/transport/anytls/skiplist/skiplist.go @@ -0,0 +1,455 @@ +package skiplist + +// This implementation is based on https://github.com/liyue201/gostl/tree/master/ds/skiplist +// (many thanks), added many optimizations, such as: +// +// - adaptive level +// - lesser search for prevs when key already exists. +// - reduce memory allocations +// - richer interface. +// +// etc. + +import ( + "math/bits" + "math/rand" + "time" + + "github.com/sagernet/sing/common" +) + +const ( + skipListMaxLevel = 40 +) + +// SkipList is a probabilistic data structure that seem likely to supplant balanced trees as the +// implementation method of choice for many applications. Skip list algorithms have the same +// asymptotic expected time bounds as balanced trees and are simpler, faster and use less space. +// +// See https://en.wikipedia.org/wiki/Skip_list for more details. +type SkipList[K any, V any] struct { + level int // Current level, may increase dynamically during insertion + len int // Total elements numner in the skiplist. + head skipListNode[K, V] // head.next[level] is the head of each level. + // This cache is used to save the previous nodes when modifying the skip list to avoid + // allocating memory each time it is called. + prevsCache []*skipListNode[K, V] + rander *rand.Rand + impl skipListImpl[K, V] +} + +// NewSkipList creates a new SkipList for Ordered key type. +func NewSkipList[K Ordered, V any]() *SkipList[K, V] { + sl := skipListOrdered[K, V]{} + sl.init() + sl.impl = (skipListImpl[K, V])(&sl) + return &sl.SkipList +} + +// NewSkipListFromMap creates a new SkipList from a map. +func NewSkipListFromMap[K Ordered, V any](m map[K]V) *SkipList[K, V] { + sl := NewSkipList[K, V]() + for k, v := range m { + sl.Insert(k, v) + } + return sl +} + +// NewSkipListFunc creates a new SkipList with specified compare function keyCmp. +func NewSkipListFunc[K any, V any](keyCmp CompareFn[K]) *SkipList[K, V] { + sl := skipListFunc[K, V]{} + sl.init() + sl.keyCmp = keyCmp + sl.impl = skipListImpl[K, V](&sl) + return &sl.SkipList +} + +// IsEmpty implements the Container interface. +func (sl *SkipList[K, V]) IsEmpty() bool { + return sl.len == 0 +} + +// Len implements the Container interface. +func (sl *SkipList[K, V]) Len() int { + return sl.len +} + +// Clear implements the Container interface. +func (sl *SkipList[K, V]) Clear() { + for i := range sl.head.next { + sl.head.next[i] = nil + } + sl.level = 1 + sl.len = 0 +} + +// Iterate return an iterator to the skiplist. +func (sl *SkipList[K, V]) Iterate() MapIterator[K, V] { + return &skipListIterator[K, V]{sl.head.next[0], nil} +} + +// Insert inserts a key-value pair into the skiplist. +// If the key is already in the skip list, it's value will be updated. +func (sl *SkipList[K, V]) Insert(key K, value V) { + node, prevs := sl.impl.findInsertPoint(key) + + if node != nil { + // Already exist, update the value + node.value = value + return + } + + level := sl.randomLevel() + node = newSkipListNode(level, key, value) + + for i := 0; i < common.Min(level, sl.level); i++ { + node.next[i] = prevs[i].next[i] + prevs[i].next[i] = node + } + + if level > sl.level { + for i := sl.level; i < level; i++ { + sl.head.next[i] = node + } + sl.level = level + } + + sl.len++ +} + +// Find returns the value associated with the passed key if the key is in the skiplist, otherwise +// returns nil. +func (sl *SkipList[K, V]) Find(key K) *V { + node := sl.impl.findNode(key) + if node != nil { + return &node.value + } + return nil +} + +// Has implement the Map interface. +func (sl *SkipList[K, V]) Has(key K) bool { + return sl.impl.findNode(key) != nil +} + +// LowerBound returns an iterator to the first element in the skiplist that +// does not satisfy element < value (i.e. greater or equal to), +// or a end itetator if no such element is found. +func (sl *SkipList[K, V]) LowerBound(key K) MapIterator[K, V] { + return &skipListIterator[K, V]{sl.impl.lowerBound(key), nil} +} + +// UpperBound returns an iterator to the first element in the skiplist that +// does not satisfy value < element (i.e. strictly greater), +// or a end itetator if no such element is found. +func (sl *SkipList[K, V]) UpperBound(key K) MapIterator[K, V] { + return &skipListIterator[K, V]{sl.impl.upperBound(key), nil} +} + +// FindRange returns an iterator in range [first, last) (last is not includeed). +func (sl *SkipList[K, V]) FindRange(first, last K) MapIterator[K, V] { + return &skipListIterator[K, V]{sl.impl.lowerBound(first), sl.impl.upperBound(last)} +} + +// Remove removes the key-value pair associated with the passed key and returns true if the key is +// in the skiplist, otherwise returns false. +func (sl *SkipList[K, V]) Remove(key K) bool { + node, prevs := sl.impl.findRemovePoint(key) + if node == nil { + return false + } + for i, v := range node.next { + prevs[i].next[i] = v + } + for sl.level > 1 && sl.head.next[sl.level-1] == nil { + sl.level-- + } + sl.len-- + return true +} + +// ForEach implements the Map interface. +func (sl *SkipList[K, V]) ForEach(op func(K, V)) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + op(e.key, e.value) + } +} + +// ForEachMutable implements the Map interface. +func (sl *SkipList[K, V]) ForEachMutable(op func(K, *V)) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + op(e.key, &e.value) + } +} + +// ForEachIf implements the Map interface. +func (sl *SkipList[K, V]) ForEachIf(op func(K, V) bool) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + if !op(e.key, e.value) { + return + } + } +} + +// ForEachMutableIf implements the Map interface. +func (sl *SkipList[K, V]) ForEachMutableIf(op func(K, *V) bool) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + if !op(e.key, &e.value) { + return + } + } +} + +/// SkipList implementation part. + +type skipListNode[K any, V any] struct { + key K + value V + next []*skipListNode[K, V] +} + +//go:generate bash ./skiplist_newnode_generate.sh skipListMaxLevel skiplist_newnode.go +// func newSkipListNode[K Ordered, V any](level int, key K, value V) *skipListNode[K, V] + +type skipListIterator[K any, V any] struct { + node, end *skipListNode[K, V] +} + +func (it *skipListIterator[K, V]) IsNotEnd() bool { + return it.node != it.end +} + +func (it *skipListIterator[K, V]) MoveToNext() { + it.node = it.node.next[0] +} + +func (it *skipListIterator[K, V]) Key() K { + return it.node.key +} + +func (it *skipListIterator[K, V]) Value() V { + return it.node.value +} + +// skipListImpl is an interface to provide different implementation for Ordered key or CompareFn. +// +// We can use CompareFn to cumpare Ordered keys, but a separated implementation is much faster. +// We don't make the whole skip list an interface, in order to share the type independented method. +// And because these methods are called directly without going through the interface, they are also +// much faster. +type skipListImpl[K any, V any] interface { + findNode(key K) *skipListNode[K, V] + lowerBound(key K) *skipListNode[K, V] + upperBound(key K) *skipListNode[K, V] + findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) + findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) +} + +func (sl *SkipList[K, V]) init() { + sl.level = 1 + // #nosec G404 -- This is not a security condition + sl.rander = rand.New(rand.NewSource(time.Now().Unix())) + sl.prevsCache = make([]*skipListNode[K, V], skipListMaxLevel) + sl.head.next = make([]*skipListNode[K, V], skipListMaxLevel) +} + +func (sl *SkipList[K, V]) randomLevel() int { + total := uint64(1)< 3 && 1<<(level-3) > sl.len { + level-- + } + + return level +} + +/// skipListOrdered part + +// skipListOrdered is the skip list implementation for Ordered types. +type skipListOrdered[K Ordered, V any] struct { + SkipList[K, V] +} + +func (sl *skipListOrdered[K, V]) findNode(key K) *skipListNode[K, V] { + return sl.doFindNode(key, true) +} + +func (sl *skipListOrdered[K, V]) doFindNode(key K, eq bool) *skipListNode[K, V] { + // This function execute the job of findNode if eq is true, otherwise lowBound. + // Passing the control variable eq is ugly but it's faster than testing node + // again outside the function in findNode. + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for cur := prev.next[i]; cur != nil; cur = cur.next[i] { + if cur.key == key { + return cur + } + if cur.key > key { + // All other node in this level must be greater than the key, + // search the next level. + break + } + prev = cur + } + } + if eq { + return nil + } + return prev.next[0] +} + +func (sl *skipListOrdered[K, V]) lowerBound(key K) *skipListNode[K, V] { + return sl.doFindNode(key, false) +} + +func (sl *skipListOrdered[K, V]) upperBound(key K) *skipListNode[K, V] { + node := sl.lowerBound(key) + if node != nil && node.key == key { + return node.next[0] + } + return node +} + +// findInsertPoint returns (*node, nil) to the existed node if the key exists, +// or (nil, []*node) to the previous nodes if the key doesn't exist +func (sl *skipListOrdered[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for next := prev.next[i]; next != nil; next = next.next[i] { + if next.key == key { + // The key is already existed, prevs are useless because no new node insertion. + // stop searching. + return next, nil + } + if next.key > key { + // All other node in this level must be greater than the key, + // search the next level. + break + } + prev = next + } + prevs[i] = prev + } + return nil, prevs +} + +// findRemovePoint finds the node which match the key and it's previous nodes. +func (sl *skipListOrdered[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.findPrevNodes(key) + node := prevs[0].next[0] + if node == nil || node.key != key { + return nil, nil + } + return node, prevs +} + +func (sl *skipListOrdered[K, V]) findPrevNodes(key K) []*skipListNode[K, V] { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for next := prev.next[i]; next != nil; next = next.next[i] { + if next.key >= key { + break + } + prev = next + } + prevs[i] = prev + } + return prevs +} + +/// skipListFunc part + +// skipListFunc is the skip list implementation which compare keys with func. +type skipListFunc[K any, V any] struct { + SkipList[K, V] + keyCmp CompareFn[K] +} + +func (sl *skipListFunc[K, V]) findNode(key K) *skipListNode[K, V] { + node := sl.lowerBound(key) + if node != nil && sl.keyCmp(node.key, key) == 0 { + return node + } + return nil +} + +func (sl *skipListFunc[K, V]) lowerBound(key K) *skipListNode[K, V] { + var prev = &sl.head + for i := sl.level - 1; i >= 0; i-- { + cur := prev.next[i] + for ; cur != nil; cur = cur.next[i] { + cmpRet := sl.keyCmp(cur.key, key) + if cmpRet == 0 { + return cur + } + if cmpRet > 0 { + break + } + prev = cur + } + } + return prev.next[0] +} + +func (sl *skipListFunc[K, V]) upperBound(key K) *skipListNode[K, V] { + node := sl.lowerBound(key) + if node != nil && sl.keyCmp(node.key, key) == 0 { + return node.next[0] + } + return node +} + +// findInsertPoint returns (*node, nil) to the existed node if the key exists, +// or (nil, []*node) to the previous nodes if the key doesn't exist +func (sl *skipListFunc[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for cur := prev.next[i]; cur != nil; cur = cur.next[i] { + r := sl.keyCmp(cur.key, key) + if r == 0 { + // The key is already existed, prevs are useless because no new node insertion. + // stop searching. + return cur, nil + } + if r > 0 { + // All other node in this level must be greater than the key, + // search the next level. + break + } + prev = cur + } + prevs[i] = prev + } + return nil, prevs +} + +// findRemovePoint finds the node which match the key and it's previous nodes. +func (sl *skipListFunc[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.findPrevNodes(key) + node := prevs[0].next[0] + if node == nil || sl.keyCmp(node.key, key) != 0 { + return nil, nil + } + return node, prevs +} + +func (sl *skipListFunc[K, V]) findPrevNodes(key K) []*skipListNode[K, V] { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for next := prev.next[i]; next != nil; next = next.next[i] { + if sl.keyCmp(next.key, key) >= 0 { + break + } + prev = next + } + prevs[i] = prev + } + return prevs +} diff --git a/transport/anytls/skiplist/skiplist_newnode.go b/transport/anytls/skiplist/skiplist_newnode.go new file mode 100644 index 00000000..4e8a6d88 --- /dev/null +++ b/transport/anytls/skiplist/skiplist_newnode.go @@ -0,0 +1,297 @@ +// AUTO GENERATED CODE, DON'T EDIT!!! +// EDIT skiplist_newnode_generate.sh accordingly. + +package skiplist + +// newSkipListNode creates a new node initialized with specified key, value and next slice. +func newSkipListNode[K any, V any](level int, key K, value V) *skipListNode[K, V] { + // For nodes with each levels, point their next slice to the nexts array allocated together, + // which can reduce 1 memory allocation and improve performance. + // + // The generics of the golang doesn't support non-type parameters like in C++, + // so we have to generate it manually. + switch level { + case 1: + n := struct { + head skipListNode[K, V] + nexts [1]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 2: + n := struct { + head skipListNode[K, V] + nexts [2]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 3: + n := struct { + head skipListNode[K, V] + nexts [3]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 4: + n := struct { + head skipListNode[K, V] + nexts [4]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 5: + n := struct { + head skipListNode[K, V] + nexts [5]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 6: + n := struct { + head skipListNode[K, V] + nexts [6]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 7: + n := struct { + head skipListNode[K, V] + nexts [7]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 8: + n := struct { + head skipListNode[K, V] + nexts [8]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 9: + n := struct { + head skipListNode[K, V] + nexts [9]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 10: + n := struct { + head skipListNode[K, V] + nexts [10]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 11: + n := struct { + head skipListNode[K, V] + nexts [11]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 12: + n := struct { + head skipListNode[K, V] + nexts [12]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 13: + n := struct { + head skipListNode[K, V] + nexts [13]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 14: + n := struct { + head skipListNode[K, V] + nexts [14]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 15: + n := struct { + head skipListNode[K, V] + nexts [15]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 16: + n := struct { + head skipListNode[K, V] + nexts [16]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 17: + n := struct { + head skipListNode[K, V] + nexts [17]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 18: + n := struct { + head skipListNode[K, V] + nexts [18]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 19: + n := struct { + head skipListNode[K, V] + nexts [19]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 20: + n := struct { + head skipListNode[K, V] + nexts [20]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 21: + n := struct { + head skipListNode[K, V] + nexts [21]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 22: + n := struct { + head skipListNode[K, V] + nexts [22]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 23: + n := struct { + head skipListNode[K, V] + nexts [23]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 24: + n := struct { + head skipListNode[K, V] + nexts [24]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 25: + n := struct { + head skipListNode[K, V] + nexts [25]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 26: + n := struct { + head skipListNode[K, V] + nexts [26]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 27: + n := struct { + head skipListNode[K, V] + nexts [27]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 28: + n := struct { + head skipListNode[K, V] + nexts [28]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 29: + n := struct { + head skipListNode[K, V] + nexts [29]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 30: + n := struct { + head skipListNode[K, V] + nexts [30]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 31: + n := struct { + head skipListNode[K, V] + nexts [31]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 32: + n := struct { + head skipListNode[K, V] + nexts [32]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 33: + n := struct { + head skipListNode[K, V] + nexts [33]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 34: + n := struct { + head skipListNode[K, V] + nexts [34]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 35: + n := struct { + head skipListNode[K, V] + nexts [35]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 36: + n := struct { + head skipListNode[K, V] + nexts [36]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 37: + n := struct { + head skipListNode[K, V] + nexts [37]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 38: + n := struct { + head skipListNode[K, V] + nexts [38]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 39: + n := struct { + head skipListNode[K, V] + nexts [39]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 40: + n := struct { + head skipListNode[K, V] + nexts [40]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + } + + panic("should not reach here") +} diff --git a/transport/anytls/skiplist/types.go b/transport/anytls/skiplist/types.go new file mode 100644 index 00000000..c534f460 --- /dev/null +++ b/transport/anytls/skiplist/types.go @@ -0,0 +1,75 @@ +package skiplist + +// Signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type Signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type Unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type Integer interface { + Signed | Unsigned +} + +// Float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type Float interface { + ~float32 | ~float64 +} + +// Ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. +type Ordered interface { + Integer | Float | ~string +} + +// Numeric is a constraint that permits any numeric type. +type Numeric interface { + Integer | Float +} + +// LessFn is a function that returns whether 'a' is less than 'b'. +type LessFn[T any] func(a, b T) bool + +// CompareFn is a 3 way compare function that +// returns 1 if a > b, +// returns 0 if a == b, +// returns -1 if a < b. +type CompareFn[T any] func(a, b T) int + +// HashFn is a function that returns the hash of 't'. +type HashFn[T any] func(t T) uint64 + +// Equals wraps the '==' operator for comparable types. +func Equals[T comparable](a, b T) bool { + return a == b +} + +// Less wraps the '<' operator for ordered types. +func Less[T Ordered](a, b T) bool { + return a < b +} + +// OrderedCompare provide default CompareFn for ordered types. +func OrderedCompare[T Ordered](a, b T) int { + if a < b { + return -1 + } + if a > b { + return 1 + } + return 0 +} diff --git a/transport/anytls/util/routine.go b/transport/anytls/util/routine.go new file mode 100644 index 00000000..4fdfbdd4 --- /dev/null +++ b/transport/anytls/util/routine.go @@ -0,0 +1,28 @@ +package util + +import ( + "context" + "runtime/debug" + "time" + + "github.com/metacubex/mihomo/log" +) + +func StartRoutine(ctx context.Context, d time.Duration, f func()) { + go func() { + defer func() { + if r := recover(); r != nil { + log.Errorln("[BUG] %v %s", r, string(debug.Stack())) + } + }() + for { + time.Sleep(d) + f() + select { + case <-ctx.Done(): + return + default: + } + } + }() +} diff --git a/transport/anytls/util/string_map.go b/transport/anytls/util/string_map.go new file mode 100644 index 00000000..27fb3581 --- /dev/null +++ b/transport/anytls/util/string_map.go @@ -0,0 +1,27 @@ +package util + +import ( + "strings" +) + +type StringMap map[string]string + +func (s StringMap) ToBytes() []byte { + var lines []string + for k, v := range s { + lines = append(lines, k+"="+v) + } + return []byte(strings.Join(lines, "\n")) +} + +func StringMapFromBytes(b []byte) StringMap { + var m = make(StringMap) + var lines = strings.Split(string(b), "\n") + for _, line := range lines { + v := strings.SplitN(line, "=", 2) + if len(v) == 2 { + m[v[0]] = v[1] + } + } + return m +}