From 808fdcf624648b80a74cc7014374d3781364d750 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 17 Feb 2025 19:43:58 +0800 Subject: [PATCH] chore: code cleanup --- adapter/outbound/anytls.go | 4 +-- listener/anytls/server.go | 3 +- transport/anytls/client.go | 8 ++--- transport/anytls/padding/padding.go | 8 ++--- transport/anytls/session/client.go | 10 +++--- transport/anytls/session/session.go | 44 +++++++++++++-------------- transport/anytls/skiplist/skiplist.go | 8 +++-- transport/trojan/trojan.go | 2 +- transport/vmess/tls.go | 4 +-- 9 files changed, 46 insertions(+), 45 deletions(-) diff --git a/adapter/outbound/anytls.go b/adapter/outbound/anytls.go index 8af33f20..a73b3005 100644 --- a/adapter/outbound/anytls.go +++ b/adapter/outbound/anytls.go @@ -16,7 +16,7 @@ import ( tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/anytls" - "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/uot" ) @@ -130,7 +130,7 @@ func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) { dialer: singDialer, } runtime.SetFinalizer(outbound, func(o *AnyTLS) { - common.Close(o.client) + _ = o.client.Close() }) return outbound, nil diff --git a/listener/anytls/server.go b/listener/anytls/server.go index 5d860e8a..f89f2277 100644 --- a/listener/anytls/server.go +++ b/listener/anytls/server.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/metacubex/mihomo/adapter/inbound" + "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/buf" N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" @@ -17,7 +18,7 @@ import ( "github.com/metacubex/mihomo/listener/sing" "github.com/metacubex/mihomo/transport/anytls/padding" "github.com/metacubex/mihomo/transport/anytls/session" - "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" diff --git a/transport/anytls/client.go b/transport/anytls/client.go index 2076019e..8fe65b6f 100644 --- a/transport/anytls/client.go +++ b/transport/anytls/client.go @@ -8,13 +8,13 @@ import ( "net" "time" - tlsC "github.com/metacubex/mihomo/component/tls" + "github.com/metacubex/mihomo/common/atomic" + "github.com/metacubex/mihomo/common/buf" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/anytls/padding" "github.com/metacubex/mihomo/transport/anytls/session" "github.com/metacubex/mihomo/transport/vmess" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -91,7 +91,7 @@ func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, err ctx, cancel := context.WithTimeout(ctx, C.DefaultTLSTimeout) defer cancel() - err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) + err := utlsConn.HandshakeContext(ctx) return utlsConn, err } } diff --git a/transport/anytls/padding/padding.go b/transport/anytls/padding/padding.go index e881e573..addd47c2 100644 --- a/transport/anytls/padding/padding.go +++ b/transport/anytls/padding/padding.go @@ -8,10 +8,8 @@ import ( "strconv" "strings" + "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/transport/anytls/util" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" ) const CheckMark = -1 @@ -73,7 +71,9 @@ func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) (pktSizes []int) if err != nil { continue } - _min, _max = common.Min(_min, _max), common.Max(_min, _max) + if _min > _max { + _min, _max = _max, _min + } if _min <= 0 || _max <= 0 { continue } diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go index 5a853478..869a4b31 100644 --- a/transport/anytls/session/client.go +++ b/transport/anytls/session/client.go @@ -7,14 +7,12 @@ import ( "math" "net" "sync" - "sync/atomic" "time" + "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/transport/anytls/padding" "github.com/metacubex/mihomo/transport/anytls/skiplist" "github.com/metacubex/mihomo/transport/anytls/util" - "github.com/sagernet/sing/common" - singAtomic "github.com/sagernet/sing/common/atomic" ) type Client struct { @@ -27,12 +25,12 @@ type Client struct { idleSession *skiplist.SkipList[uint64, *Session] idleSessionLock sync.Mutex - padding *singAtomic.TypedValue[*padding.PaddingFactory] + padding *atomic.TypedValue[*padding.PaddingFactory] idleSessionTimeout time.Duration } -func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client { +func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *atomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client { c := &Client{ dialOut: dialOut, padding: _padding, @@ -68,7 +66,7 @@ func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { } stream, err = session.OpenStream() if err != nil { - common.Close(session, stream) + _ = session.Close() continue } break diff --git a/transport/anytls/session/session.go b/transport/anytls/session/session.go index e7186b63..145eb032 100644 --- a/transport/anytls/session/session.go +++ b/transport/anytls/session/session.go @@ -7,15 +7,15 @@ import ( "net" "runtime/debug" "sync" - "sync/atomic" "time" + "github.com/metacubex/mihomo/common/atomic" + "github.com/metacubex/mihomo/common/buf" + "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/transport/anytls/padding" "github.com/metacubex/mihomo/transport/anytls/util" - singAtomic "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" ) type Session struct { @@ -33,7 +33,7 @@ type Session struct { // pool seq uint64 idleSince time.Time - padding *singAtomic.TypedValue[*padding.PaddingFactory] + padding *atomic.TypedValue[*padding.PaddingFactory] // client isClient bool @@ -45,7 +45,7 @@ type Session struct { onNewStream func(stream *Stream) } -func NewClientSession(conn net.Conn, _padding *singAtomic.TypedValue[*padding.PaddingFactory]) *Session { +func NewClientSession(conn net.Conn, _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session { s := &Session{ conn: conn, isClient: true, @@ -56,7 +56,7 @@ func NewClientSession(conn net.Conn, _padding *singAtomic.TypedValue[*padding.Pa return s } -func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *singAtomic.TypedValue[*padding.PaddingFactory]) *Session { +func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session { s := &Session{ conn: conn, onNewStream: onNewStream, @@ -169,7 +169,7 @@ func (s *Session) recvLoop() error { switch hdr.Cmd() { case cmdPSH: if hdr.Length() > 0 { - buffer := buf.Get(int(hdr.Length())) + buffer := pool.Get(int(hdr.Length())) if _, err := io.ReadFull(s.conn, buffer); err == nil { s.streamLock.RLock() stream, ok := s.streams[sid] @@ -177,9 +177,9 @@ func (s *Session) recvLoop() error { if ok { stream.pipeW.Write(buffer) } - buf.Put(buffer) + pool.Put(buffer) } else { - buf.Put(buffer) + pool.Put(buffer) return err } } @@ -211,18 +211,18 @@ func (s *Session) recvLoop() error { //logrus.Debugln("stream fin", sid, s.streams) case cmdWaste: if hdr.Length() > 0 { - buffer := buf.Get(int(hdr.Length())) + buffer := pool.Get(int(hdr.Length())) if _, err := io.ReadFull(s.conn, buffer); err != nil { - buf.Put(buffer) + pool.Put(buffer) return err } - buf.Put(buffer) + pool.Put(buffer) } case cmdSettings: if hdr.Length() > 0 { - buffer := buf.Get(int(hdr.Length())) + buffer := pool.Get(int(hdr.Length())) if _, err := io.ReadFull(s.conn, buffer); err != nil { - buf.Put(buffer) + pool.Put(buffer) return err } if !s.isClient { @@ -235,31 +235,31 @@ func (s *Session) recvLoop() error { f.data = paddingF.RawScheme _, err = s.writeFrame(f) if err != nil { - buf.Put(buffer) + pool.Put(buffer) return err } } } - buf.Put(buffer) + pool.Put(buffer) } case cmdAlert: if hdr.Length() > 0 { - buffer := buf.Get(int(hdr.Length())) + buffer := pool.Get(int(hdr.Length())) if _, err := io.ReadFull(s.conn, buffer); err != nil { - buf.Put(buffer) + pool.Put(buffer) return err } if s.isClient { log.Errorln("[Alert from server] %s", string(buffer)) } - buf.Put(buffer) + pool.Put(buffer) return nil } case cmdUpdatePaddingScheme: if hdr.Length() > 0 { - buffer := buf.Get(int(hdr.Length())) + buffer := pool.Get(int(hdr.Length())) if _, err := io.ReadFull(s.conn, buffer); err != nil { - buf.Put(buffer) + pool.Put(buffer) return err } if s.isClient { @@ -269,7 +269,7 @@ func (s *Session) recvLoop() error { log.Warnln("[Update padding failed] %x\n", md5.Sum(buffer)) } } - buf.Put(buffer) + pool.Put(buffer) } default: // I don't know what command it is (can't have data) diff --git a/transport/anytls/skiplist/skiplist.go b/transport/anytls/skiplist/skiplist.go index a4a0ffbb..f1ce402a 100644 --- a/transport/anytls/skiplist/skiplist.go +++ b/transport/anytls/skiplist/skiplist.go @@ -14,8 +14,6 @@ import ( "math/bits" "math/rand" "time" - - "github.com/sagernet/sing/common" ) const ( @@ -102,7 +100,11 @@ func (sl *SkipList[K, V]) Insert(key K, value V) { level := sl.randomLevel() node = newSkipListNode(level, key, value) - for i := 0; i < common.Min(level, sl.level); i++ { + minLevel := level + if sl.level < level { + minLevel = sl.level + } + for i := 0; i < minLevel; i++ { node.next[i] = prevs[i].next[i] prevs[i].next[i] = node } diff --git a/transport/trojan/trojan.go b/transport/trojan/trojan.go index 17f403c1..c1ebb8da 100644 --- a/transport/trojan/trojan.go +++ b/transport/trojan/trojan.go @@ -93,7 +93,7 @@ func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) defer cancel() - err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) + err := utlsConn.HandshakeContext(ctx) return utlsConn, err } } else { diff --git a/transport/vmess/tls.go b/transport/vmess/tls.go index bdaa8ccc..82a484f1 100644 --- a/transport/vmess/tls.go +++ b/transport/vmess/tls.go @@ -36,7 +36,7 @@ func StreamTLSConn(ctx context.Context, conn net.Conn, cfg *TLSConfig) (net.Conn if cfg.Reality == nil { utlsConn, valid := GetUTLSConn(conn, cfg.ClientFingerprint, tlsConfig) if valid { - err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) + err = utlsConn.HandshakeContext(ctx) return utlsConn, err } } else { @@ -53,7 +53,7 @@ func StreamTLSConn(ctx context.Context, conn net.Conn, cfg *TLSConfig) (net.Conn return tlsConn, err } -func GetUTLSConn(conn net.Conn, ClientFingerprint string, tlsConfig *tls.Config) (net.Conn, bool) { +func GetUTLSConn(conn net.Conn, ClientFingerprint string, tlsConfig *tls.Config) (*tlsC.UConn, bool) { if fingerprint, exists := tlsC.GetFingerprint(ClientFingerprint); exists { utlsConn := tlsC.UClient(conn, tlsConfig, fingerprint)