diff --git a/listener/anytls/server.go b/listener/anytls/server.go index f89f2277..667e8777 100644 --- a/listener/anytls/server.go +++ b/listener/anytls/server.go @@ -177,6 +177,6 @@ func (l *Listener) HandleConn(conn net.Conn, h *sing.ListenerHandler) { Destination: destination, }) }, &l.padding) - session.Run(true) + session.Run() session.Close() } diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go index 869a4b31..2312a6ff 100644 --- a/transport/anytls/session/client.go +++ b/transport/anytls/session/client.go @@ -123,7 +123,7 @@ func (c *Client) createSession(ctx context.Context) (*Session, error) { c.idleSession.Remove(math.MaxUint64 - session.seq) c.idleSessionLock.Unlock() } - session.Run(false) + session.Run() return session, nil } diff --git a/transport/anytls/session/session.go b/transport/anytls/session/session.go index ce4d580f..946ae9f2 100644 --- a/transport/anytls/session/session.go +++ b/transport/anytls/session/session.go @@ -36,10 +36,11 @@ type Session struct { padding *atomic.TypedValue[*padding.PaddingFactory] // client - isClient bool - buffering bool - buffer []byte - pktCounter atomic.Uint32 + isClient bool + sendPadding bool + buffering bool + buffer []byte + pktCounter atomic.Uint32 // server onNewStream func(stream *Stream) @@ -47,9 +48,10 @@ type Session struct { func NewClientSession(conn net.Conn, _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session { s := &Session{ - conn: conn, - isClient: true, - padding: _padding, + conn: conn, + isClient: true, + sendPadding: true, + padding: _padding, } s.die = make(chan struct{}) s.streams = make(map[uint32]*Stream) @@ -60,7 +62,6 @@ func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding s := &Session{ conn: conn, onNewStream: onNewStream, - isClient: false, padding: _padding, } s.die = make(chan struct{}) @@ -68,8 +69,8 @@ func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding return s } -func (s *Session) Run(isServer bool) { - if isServer { +func (s *Session) Run() { + if !s.isClient { s.recvLoop() return } @@ -319,7 +320,7 @@ func (s *Session) writeConn(b []byte) (n int, err error) { } // calulate & send padding - if s.isClient { + if s.sendPadding { pkt := s.pktCounter.Add(1) paddingF := s.padding.Load() if pkt < paddingF.Stop { @@ -333,7 +334,6 @@ func (s *Session) writeConn(b []byte) (n int, err error) { continue } } - // logrus.Debugln(pkt, "write", l, "len", remainPayloadLen, "remain", remainPayloadLen-l) if remainPayloadLen > l { // this packet is all payload _, err = s.conn.Write(b[:l]) if err != nil { @@ -371,7 +371,12 @@ func (s *Session) writeConn(b []byte) (n int, err error) { // maybe still remain payload to write if len(b) == 0 { return + } else { + n2, err := s.conn.Write(b) + return n + n2, err } + } else { + s.sendPadding = false } }