From e2b75b35bbfd3be61b2b5a37142442bb259f99e4 Mon Sep 17 00:00:00 2001 From: anytls Date: Wed, 19 Feb 2025 15:54:56 +0800 Subject: [PATCH] chore: update anytls (#1851) * Implement deadline for `Stream` * chore: code cleanup * fix: buffer release * fix: do not use buffer for `cmdUpdatePaddingScheme` --------- Co-authored-by: anytls --- listener/anytls/server.go | 4 +- transport/anytls/client.go | 2 + transport/anytls/pipe/deadline.go | 74 +++++++++ transport/anytls/pipe/io_pipe.go | 232 ++++++++++++++++++++++++++++ transport/anytls/session/client.go | 2 +- transport/anytls/session/session.go | 42 ++--- transport/anytls/session/stream.go | 23 ++- 7 files changed, 352 insertions(+), 27 deletions(-) create mode 100644 transport/anytls/pipe/deadline.go create mode 100644 transport/anytls/pipe/io_pipe.go diff --git a/listener/anytls/server.go b/listener/anytls/server.go index f89f2277..31a7c55a 100644 --- a/listener/anytls/server.go +++ b/listener/anytls/server.go @@ -135,6 +135,8 @@ func (l *Listener) HandleConn(conn net.Conn, h *sing.ListenerHandler) { defer conn.Close() b := buf.NewPacket() + defer b.Release() + _, err := b.ReadOnceFrom(conn) if err != nil { return @@ -177,6 +179,6 @@ func (l *Listener) HandleConn(conn net.Conn, h *sing.ListenerHandler) { Destination: destination, }) }, &l.padding) - session.Run(true) + session.Run() session.Close() } diff --git a/transport/anytls/client.go b/transport/anytls/client.go index fd573ebe..19776df9 100644 --- a/transport/anytls/client.go +++ b/transport/anytls/client.go @@ -71,6 +71,8 @@ func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, err } b := buf.NewPacket() + defer b.Release() + b.Write(c.passwordSha256) var paddingLen int if pad := c.padding.Load().GenerateRecordPayloadSizes(0); len(pad) > 0 { diff --git a/transport/anytls/pipe/deadline.go b/transport/anytls/pipe/deadline.go new file mode 100644 index 00000000..29c4ec0a --- /dev/null +++ b/transport/anytls/pipe/deadline.go @@ -0,0 +1,74 @@ +package pipe + +import ( + "sync" + "time" +) + +// PipeDeadline is an abstraction for handling timeouts. +type PipeDeadline struct { + mu sync.Mutex // Guards timer and cancel + timer *time.Timer + cancel chan struct{} // Must be non-nil +} + +func MakePipeDeadline() PipeDeadline { + return PipeDeadline{cancel: make(chan struct{})} +} + +// Set sets the point in time when the deadline will time out. +// A timeout event is signaled by closing the channel returned by waiter. +// Once a timeout has occurred, the deadline can be refreshed by specifying a +// t value in the future. +// +// A zero value for t prevents timeout. +func (d *PipeDeadline) Set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil && !d.timer.Stop() { + <-d.cancel // Wait for the timer callback to finish and close cancel + } + d.timer = nil + + // Time is zero, then there is no deadline. + closed := isClosedChan(d.cancel) + if t.IsZero() { + if closed { + d.cancel = make(chan struct{}) + } + return + } + + // Time in the future, setup a timer to cancel in the future. + if dur := time.Until(t); dur > 0 { + if closed { + d.cancel = make(chan struct{}) + } + d.timer = time.AfterFunc(dur, func() { + close(d.cancel) + }) + return + } + + // Time in the past, so close immediately. + if !closed { + close(d.cancel) + } +} + +// Wait returns a channel that is closed when the deadline is exceeded. +func (d *PipeDeadline) Wait() chan struct{} { + d.mu.Lock() + defer d.mu.Unlock() + return d.cancel +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} diff --git a/transport/anytls/pipe/io_pipe.go b/transport/anytls/pipe/io_pipe.go new file mode 100644 index 00000000..5d0fd252 --- /dev/null +++ b/transport/anytls/pipe/io_pipe.go @@ -0,0 +1,232 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Pipe adapter to connect code expecting an io.Reader +// with code expecting an io.Writer. + +package pipe + +import ( + "io" + "os" + "sync" + "time" +) + +// onceError is an object that will only store an error once. +type onceError struct { + sync.Mutex // guards following + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +// A pipe is the shared pipe structure underlying PipeReader and PipeWriter. +type pipe struct { + wrMu sync.Mutex // Serializes Write operations + wrCh chan []byte + rdCh chan int + + once sync.Once // Protects closing done + done chan struct{} + rerr onceError + werr onceError + + readDeadline PipeDeadline + writeDeadline PipeDeadline +} + +func (p *pipe) read(b []byte) (n int, err error) { + select { + case <-p.done: + return 0, p.readCloseError() + case <-p.readDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + default: + } + + select { + case bw := <-p.wrCh: + nr := copy(b, bw) + p.rdCh <- nr + return nr, nil + case <-p.done: + return 0, p.readCloseError() + case <-p.readDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + } +} + +func (p *pipe) closeRead(err error) error { + if err == nil { + err = io.ErrClosedPipe + } + p.rerr.Store(err) + p.once.Do(func() { close(p.done) }) + return nil +} + +func (p *pipe) write(b []byte) (n int, err error) { + select { + case <-p.done: + return 0, p.writeCloseError() + case <-p.writeDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + default: + p.wrMu.Lock() + defer p.wrMu.Unlock() + } + + for once := true; once || len(b) > 0; once = false { + select { + case p.wrCh <- b: + nw := <-p.rdCh + b = b[nw:] + n += nw + case <-p.done: + return n, p.writeCloseError() + case <-p.writeDeadline.Wait(): + return n, os.ErrDeadlineExceeded + } + } + return n, nil +} + +func (p *pipe) closeWrite(err error) error { + if err == nil { + err = io.EOF + } + p.werr.Store(err) + p.once.Do(func() { close(p.done) }) + return nil +} + +// readCloseError is considered internal to the pipe type. +func (p *pipe) readCloseError() error { + rerr := p.rerr.Load() + if werr := p.werr.Load(); rerr == nil && werr != nil { + return werr + } + return io.ErrClosedPipe +} + +// writeCloseError is considered internal to the pipe type. +func (p *pipe) writeCloseError() error { + werr := p.werr.Load() + if rerr := p.rerr.Load(); werr == nil && rerr != nil { + return rerr + } + return io.ErrClosedPipe +} + +// A PipeReader is the read half of a pipe. +type PipeReader struct{ pipe } + +// Read implements the standard Read interface: +// it reads data from the pipe, blocking until a writer +// arrives or the write end is closed. +// If the write end is closed with an error, that error is +// returned as err; otherwise err is EOF. +func (r *PipeReader) Read(data []byte) (n int, err error) { + return r.pipe.read(data) +} + +// Close closes the reader; subsequent writes to the +// write half of the pipe will return the error [ErrClosedPipe]. +func (r *PipeReader) Close() error { + return r.CloseWithError(nil) +} + +// CloseWithError closes the reader; subsequent writes +// to the write half of the pipe will return the error err. +// +// CloseWithError never overwrites the previous error if it exists +// and always returns nil. +func (r *PipeReader) CloseWithError(err error) error { + return r.pipe.closeRead(err) +} + +// A PipeWriter is the write half of a pipe. +type PipeWriter struct{ r PipeReader } + +// Write implements the standard Write interface: +// it writes data to the pipe, blocking until one or more readers +// have consumed all the data or the read end is closed. +// If the read end is closed with an error, that err is +// returned as err; otherwise err is [ErrClosedPipe]. +func (w *PipeWriter) Write(data []byte) (n int, err error) { + return w.r.pipe.write(data) +} + +// Close closes the writer; subsequent reads from the +// read half of the pipe will return no bytes and EOF. +func (w *PipeWriter) Close() error { + return w.CloseWithError(nil) +} + +// CloseWithError closes the writer; subsequent reads from the +// read half of the pipe will return no bytes and the error err, +// or EOF if err is nil. +// +// CloseWithError never overwrites the previous error if it exists +// and always returns nil. +func (w *PipeWriter) CloseWithError(err error) error { + return w.r.pipe.closeWrite(err) +} + +// Pipe creates a synchronous in-memory pipe. +// It can be used to connect code expecting an [io.Reader] +// with code expecting an [io.Writer]. +// +// Reads and Writes on the pipe are matched one to one +// except when multiple Reads are needed to consume a single Write. +// That is, each Write to the [PipeWriter] blocks until it has satisfied +// one or more Reads from the [PipeReader] that fully consume +// the written data. +// The data is copied directly from the Write to the corresponding +// Read (or Reads); there is no internal buffering. +// +// It is safe to call Read and Write in parallel with each other or with Close. +// Parallel calls to Read and parallel calls to Write are also safe: +// the individual calls will be gated sequentially. +// +// Added SetReadDeadline and SetWriteDeadline methods based on `io.Pipe`. +func Pipe() (*PipeReader, *PipeWriter) { + pw := &PipeWriter{r: PipeReader{pipe: pipe{ + wrCh: make(chan []byte), + rdCh: make(chan int), + done: make(chan struct{}), + readDeadline: MakePipeDeadline(), + writeDeadline: MakePipeDeadline(), + }}} + return &pw.r, pw +} + +func (p *PipeReader) SetReadDeadline(t time.Time) error { + if isClosedChan(p.done) { + return io.ErrClosedPipe + } + p.readDeadline.Set(t) + return nil +} + +func (p *PipeWriter) SetWriteDeadline(t time.Time) error { + if isClosedChan(p.r.done) { + return io.ErrClosedPipe + } + p.r.writeDeadline.Set(t) + return nil +} diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go index 869a4b31..2312a6ff 100644 --- a/transport/anytls/session/client.go +++ b/transport/anytls/session/client.go @@ -123,7 +123,7 @@ func (c *Client) createSession(ctx context.Context) (*Session, error) { c.idleSession.Remove(math.MaxUint64 - session.seq) c.idleSessionLock.Unlock() } - session.Run(false) + session.Run() return session, nil } diff --git a/transport/anytls/session/session.go b/transport/anytls/session/session.go index ce4d580f..963533ea 100644 --- a/transport/anytls/session/session.go +++ b/transport/anytls/session/session.go @@ -36,10 +36,11 @@ type Session struct { padding *atomic.TypedValue[*padding.PaddingFactory] // client - isClient bool - buffering bool - buffer []byte - pktCounter atomic.Uint32 + isClient bool + sendPadding bool + buffering bool + buffer []byte + pktCounter atomic.Uint32 // server onNewStream func(stream *Stream) @@ -47,9 +48,10 @@ type Session struct { func NewClientSession(conn net.Conn, _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session { s := &Session{ - conn: conn, - isClient: true, - padding: _padding, + conn: conn, + isClient: true, + sendPadding: true, + padding: _padding, } s.die = make(chan struct{}) s.streams = make(map[uint32]*Stream) @@ -60,7 +62,6 @@ func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding s := &Session{ conn: conn, onNewStream: onNewStream, - isClient: false, padding: _padding, } s.die = make(chan struct{}) @@ -68,8 +69,8 @@ func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding return s } -func (s *Session) Run(isServer bool) { - if isServer { +func (s *Session) Run() { + if !s.isClient { s.recvLoop() return } @@ -257,19 +258,18 @@ func (s *Session) recvLoop() error { } case cmdUpdatePaddingScheme: if hdr.Length() > 0 { - buffer := pool.Get(int(hdr.Length())) - if _, err := io.ReadFull(s.conn, buffer); err != nil { - pool.Put(buffer) + // `rawScheme` Do not use buffer to prevent subsequent misuse + rawScheme := make([]byte, int(hdr.Length())) + if _, err := io.ReadFull(s.conn, rawScheme); err != nil { return err } if s.isClient { - if padding.UpdatePaddingScheme(buffer, s.padding) { - log.Infoln("[Update padding succeed] %x\n", md5.Sum(buffer)) + if padding.UpdatePaddingScheme(rawScheme, s.padding) { + log.Infoln("[Update padding succeed] %x\n", md5.Sum(rawScheme)) } else { - log.Warnln("[Update padding failed] %x\n", md5.Sum(buffer)) + log.Warnln("[Update padding failed] %x\n", md5.Sum(rawScheme)) } } - pool.Put(buffer) } default: // I don't know what command it is (can't have data) @@ -319,7 +319,7 @@ func (s *Session) writeConn(b []byte) (n int, err error) { } // calulate & send padding - if s.isClient { + if s.sendPadding { pkt := s.pktCounter.Add(1) paddingF := s.padding.Load() if pkt < paddingF.Stop { @@ -333,7 +333,6 @@ func (s *Session) writeConn(b []byte) (n int, err error) { continue } } - // logrus.Debugln(pkt, "write", l, "len", remainPayloadLen, "remain", remainPayloadLen-l) if remainPayloadLen > l { // this packet is all payload _, err = s.conn.Write(b[:l]) if err != nil { @@ -371,7 +370,12 @@ func (s *Session) writeConn(b []byte) (n int, err error) { // maybe still remain payload to write if len(b) == 0 { return + } else { + n2, err := s.conn.Write(b) + return n + n2, err } + } else { + s.sendPadding = false } } diff --git a/transport/anytls/session/stream.go b/transport/anytls/session/stream.go index 140396e4..9f21ff04 100644 --- a/transport/anytls/session/stream.go +++ b/transport/anytls/session/stream.go @@ -6,6 +6,8 @@ import ( "os" "sync" "time" + + "github.com/metacubex/mihomo/transport/anytls/pipe" ) // Stream implements net.Conn @@ -14,8 +16,9 @@ type Stream struct { sess *Session - pipeR *io.PipeReader - pipeW *io.PipeWriter + pipeR *pipe.PipeReader + pipeW *pipe.PipeWriter + writeDeadline pipe.PipeDeadline dieOnce sync.Once dieHook func() @@ -26,7 +29,8 @@ func newStream(id uint32, sess *Session) *Stream { s := new(Stream) s.id = id s.sess = sess - s.pipeR, s.pipeW = io.Pipe() + s.pipeR, s.pipeW = pipe.Pipe() + s.writeDeadline = pipe.MakePipeDeadline() return s } @@ -37,6 +41,11 @@ func (s *Stream) Read(b []byte) (n int, err error) { // Write implements net.Conn func (s *Stream) Write(b []byte) (n int, err error) { + select { + case <-s.writeDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + default: + } f := newFrame(cmdPSH, s.id) f.data = b n, err = s.sess.writeFrame(f) @@ -67,15 +76,17 @@ func (s *Stream) sessionClose() (once bool) { } func (s *Stream) SetReadDeadline(t time.Time) error { - return os.ErrNotExist + return s.pipeR.SetReadDeadline(t) } func (s *Stream) SetWriteDeadline(t time.Time) error { - return os.ErrNotExist + s.writeDeadline.Set(t) + return nil } func (s *Stream) SetDeadline(t time.Time) error { - return os.ErrNotExist + s.SetWriteDeadline(t) + return s.SetReadDeadline(t) } // LocalAddr satisfies net.Conn interface