From 5702d28cda2ba33d13dc1c9085b0203f75198285 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 8 Mar 2024 19:27:41 +0800 Subject: [PATCH] chore: rebuild ssh outbound --- adapter/outbound/ssh.go | 103 ++++++++++++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 14 deletions(-) diff --git a/adapter/outbound/ssh.go b/adapter/outbound/ssh.go index 140a9331..a41a8132 100644 --- a/adapter/outbound/ssh.go +++ b/adapter/outbound/ssh.go @@ -6,10 +6,14 @@ import ( "os" "runtime" "strconv" + "sync" - CN "github.com/metacubex/mihomo/common/net" + N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/dialer" + "github.com/metacubex/mihomo/component/proxydialer" C "github.com/metacubex/mihomo/constant" + + "github.com/zhangyunhao116/fastrand" "golang.org/x/crypto/ssh" ) @@ -17,7 +21,7 @@ type Ssh struct { *Base option *SshOption - client *ssh.Client + client *sshClient // using a standalone struct to avoid its inner loop invalidate the Finalizer } type SshOption struct { @@ -30,18 +34,85 @@ type SshOption struct { PrivateKey string `proxy:"privateKey,omitempty"` } -func (h *Ssh) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - c, err := h.client.Dial("tcp", metadata.RemoteAddress()) +func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { + var cDialer C.Dialer = dialer.NewDialer(s.Base.DialOptions(opts...)...) + if len(s.option.DialerProxy) > 0 { + cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer) + if err != nil { + return nil, err + } + } + client, err := s.client.connect(ctx, cDialer, s.addr) if err != nil { return nil, err } - return NewConn(CN.NewRefConn(c, h), h), nil + c, err := client.DialContext(ctx, "tcp", metadata.RemoteAddress()) + if err != nil { + return nil, err + } + + return NewConn(N.NewRefConn(c, s), s), nil } -func closeSsh(h *Ssh) { - if h.client != nil { - _ = h.client.Close() +type sshClient struct { + config *ssh.ClientConfig + client *ssh.Client + cMutex sync.Mutex +} + +func (s *sshClient) connect(ctx context.Context, cDialer C.Dialer, addr string) (client *ssh.Client, err error) { + s.cMutex.Lock() + defer s.cMutex.Unlock() + if s.client != nil { + return s.client, nil } + c, err := cDialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + N.TCPKeepAlive(c) + + defer func(c net.Conn) { + safeConnClose(c, err) + }(c) + + if ctx.Done() != nil { + done := N.SetupContextForConn(ctx, c) + defer done(&err) + } + + clientConn, chans, reqs, err := ssh.NewClientConn(c, addr, s.config) + if err != nil { + return nil, err + } + client = ssh.NewClient(clientConn, chans, reqs) + + s.client = client + + go func() { + _ = client.Wait() // wait shutdown + _ = client.Close() + s.cMutex.Lock() + defer s.cMutex.Unlock() + if s.client == client { + s.client = nil + } + }() + + return client, nil +} + +func (s *sshClient) Close() error { + s.cMutex.Lock() + defer s.cMutex.Unlock() + if s.client != nil { + return s.client.Close() + } + return nil +} + +func closeSsh(s *Ssh) { + _ = s.client.Close() } func NewSsh(option SshOption) (*Ssh, error) { @@ -55,7 +126,6 @@ func NewSsh(option SshOption) (*Ssh, error) { } if option.Password == "" { - b, err := os.ReadFile(option.PrivateKey) if err != nil { return nil, err @@ -74,23 +144,28 @@ func NewSsh(option SshOption) (*Ssh, error) { } } - client, err := ssh.Dial("tcp", addr, &config) - if err != nil { - return nil, err + version := "SSH-2.0-OpenSSH_" + if fastrand.Intn(2) == 0 { + version += "7." + strconv.Itoa(fastrand.Intn(10)) + } else { + version += "8." + strconv.Itoa(fastrand.Intn(9)) } + config.ClientVersion = version outbound := &Ssh{ Base: &Base{ name: option.Name, addr: addr, tp: C.Ssh, - udp: true, + udp: false, iface: option.Interface, rmark: option.RoutingMark, prefer: C.NewDNSPrefer(option.IPVersion), }, option: &option, - client: client, + client: &sshClient{ + config: &config, + }, } runtime.SetFinalizer(outbound, closeSsh)