diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go index 58adbf94..b0c9982b 100644 --- a/transport/anytls/session/client.go +++ b/transport/anytls/session/client.go @@ -19,7 +19,7 @@ type Client struct { die context.Context dieCancel context.CancelFunc - dialOut func(ctx context.Context) (net.Conn, error) + dialOut util.DialOutFunc sessionCounter atomic.Uint64 idleSession *skiplist.SkipList[uint64, *Session] @@ -31,7 +31,7 @@ type Client struct { minIdleSession int } -func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *atomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration, minIdleSession int) *Client { +func NewClient(ctx context.Context, dialOut util.DialOutFunc, _padding *atomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration, minIdleSession int) *Client { c := &Client{ dialOut: dialOut, padding: _padding, @@ -83,10 +83,16 @@ func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { session.dieHook() } } else { - c.idleSessionLock.Lock() - session.idleSince = time.Now() - c.idleSession.Insert(math.MaxUint64-session.seq, session) - c.idleSessionLock.Unlock() + select { + case <-c.die.Done(): + // Now client has been closed + go session.Close() + default: + c.idleSessionLock.Lock() + session.idleSince = time.Now() + c.idleSession.Insert(math.MaxUint64-session.seq, session) + c.idleSessionLock.Unlock() + } } } @@ -131,7 +137,8 @@ func (c *Client) createSession(ctx context.Context) (*Session, error) { func (c *Client) Close() error { c.dieCancel() - go c.idleCleanupExpTime(time.Now()) + c.minIdleSession = 0 + go c.idleCleanupExpTime(time.Time{}) return nil } diff --git a/transport/anytls/util/type.go b/transport/anytls/util/type.go new file mode 100644 index 00000000..4f28d489 --- /dev/null +++ b/transport/anytls/util/type.go @@ -0,0 +1,8 @@ +package util + +import ( + "context" + "net" +) + +type DialOutFunc func(ctx context.Context) (net.Conn, error)